diff --git a/cmd/backup.go b/cmd/backup.go index 73ae6106f..d65582949 100644 --- a/cmd/backup.go +++ b/cmd/backup.go @@ -1,6 +1,7 @@ package cmd import ( + "bytes" "context" "github.com/pingcap/errors" @@ -25,6 +26,16 @@ const ( flagLastBackupTS = "lastbackupts" ) +type backupContext struct { + db string + table string + + isRawKv bool + startKey []byte + endKey []byte + cf string +} + func defineBackupFlags(flagSet *pflag.FlagSet) { flagSet.StringP( flagBackupTimeago, "", "", @@ -44,7 +55,7 @@ func defineBackupFlags(flagSet *pflag.FlagSet) { _ = flagSet.MarkHidden(flagBackupRateLimitUnit) } -func runBackup(flagSet *pflag.FlagSet, cmdName, db, table string) error { +func runBackup(flagSet *pflag.FlagSet, cmdName string, bc backupContext) error { ctx, cancel := context.WithCancel(defaultContext) defer cancel() @@ -110,10 +121,18 @@ func runBackup(flagSet *pflag.FlagSet, cmdName, db, table string) error { defer summary.Summary(cmdName) - ranges, backupSchemas, err := backup.BuildBackupRangeAndSchema( - mgr.GetDomain(), mgr.GetTiKV(), backupTS, db, table) - if err != nil { - return err + var ( + ranges []backup.Range + backupSchemas *backup.Schemas + ) + if bc.isRawKv { + ranges = []backup.Range{{StartKey: bc.startKey, EndKey: bc.endKey}} + } else { + ranges, backupSchemas, err = backup.BuildBackupRangeAndSchema( + mgr.GetDomain(), mgr.GetTiKV(), backupTS, bc.db, bc.table) + if err != nil { + return err + } } // The number of regions need to backup @@ -133,38 +152,39 @@ func runBackup(flagSet *pflag.FlagSet, cmdName, db, table string) error { updateCh := utils.StartProgress( ctx, cmdName, int64(approximateRegions), !HasLogFile()) err = client.BackupRanges( - ctx, ranges, lastBackupTS, backupTS, ratelimit, concurrency, updateCh) + ctx, ranges, lastBackupTS, backupTS, ratelimit, concurrency, updateCh, bc.isRawKv, bc.cf) if err != nil { return err } // Backup has finished close(updateCh) - // Checksum - backupSchemasConcurrency := backup.DefaultSchemaConcurrency - if 
backupSchemas.Len() < backupSchemasConcurrency { - backupSchemasConcurrency = backupSchemas.Len() - } - updateCh = utils.StartProgress( - ctx, "Checksum", int64(backupSchemas.Len()), !HasLogFile()) - backupSchemas.SetSkipChecksum(!checksum) - backupSchemas.Start( - ctx, mgr.GetTiKV(), backupTS, uint(backupSchemasConcurrency), updateCh) - - err = client.CompleteMeta(backupSchemas) - if err != nil { - return err - } + if backupSchemas != nil { + // Checksum + backupSchemasConcurrency := backup.DefaultSchemaConcurrency + if backupSchemas.Len() < backupSchemasConcurrency { + backupSchemasConcurrency = backupSchemas.Len() + } + updateCh = utils.StartProgress( + ctx, "Checksum", int64(backupSchemas.Len()), !HasLogFile()) + backupSchemas.SetSkipChecksum(!checksum) + backupSchemas.Start( + ctx, mgr.GetTiKV(), backupTS, uint(backupSchemasConcurrency), updateCh) - valid, err := client.FastChecksum() - if err != nil { - return err - } - if !valid { - log.Error("backup FastChecksum failed!") + err = client.CompleteMeta(backupSchemas) + if err != nil { + return err + } + valid, err := client.FastChecksum() + if err != nil { + return err + } + if !valid { + log.Error("backup FastChecksum failed!") + } + // Checksum has finished + close(updateCh) } - // Checksum has finished - close(updateCh) err = client.SaveBackupMeta(ctx) if err != nil { @@ -198,6 +218,7 @@ func NewBackupCommand() *cobra.Command { newFullBackupCommand(), newDbBackupCommand(), newTableBackupCommand(), + newRawBackupCommand(), ) defineBackupFlags(command.PersistentFlags()) @@ -211,7 +232,8 @@ func newFullBackupCommand() *cobra.Command { Short: "backup all database", RunE: func(command *cobra.Command, _ []string) error { // empty db/table means full backup. 
- return runBackup(command.Flags(), "Full backup", "", "") + bc := backupContext{db: "", table: "", isRawKv: false} + return runBackup(command.Flags(), "Full backup", bc) }, } return command @@ -230,7 +252,8 @@ func newDbBackupCommand() *cobra.Command { if len(db) == 0 { return errors.Errorf("empty database name is not allowed") } - return runBackup(command.Flags(), "Database backup", db, "") + bc := backupContext{db: db, table: "", isRawKv: false} + return runBackup(command.Flags(), "Database backup", bc) }, } command.Flags().StringP(flagDatabase, "", "", "backup a table in the specific db") @@ -259,7 +282,8 @@ func newTableBackupCommand() *cobra.Command { if len(table) == 0 { return errors.Errorf("empty table name is not allowed") } - return runBackup(command.Flags(), "Table backup", db, table) + bc := backupContext{db: db, table: table, isRawKv: false} + return runBackup(command.Flags(), "Table backup", bc) }, } command.Flags().StringP(flagDatabase, "", "", "backup a table in the specific db") @@ -268,3 +292,45 @@ func newTableBackupCommand() *cobra.Command { _ = command.MarkFlagRequired(flagTable) return command } + +// newRawBackupCommand return a raw kv range backup subcommand. 
+func newRawBackupCommand() *cobra.Command { + command := &cobra.Command{ + Use: "raw", + Short: "backup a raw kv range from TiKV cluster", + RunE: func(command *cobra.Command, _ []string) error { + start, err := command.Flags().GetString("start") + if err != nil { + return err + } + startKey, err := utils.ParseKey(command.Flags(), start) + if err != nil { + return err + } + end, err := command.Flags().GetString("end") + if err != nil { + return err + } + endKey, err := utils.ParseKey(command.Flags(), end) + if err != nil { + return err + } + + cf, err := command.Flags().GetString("cf") + if err != nil { + return err + } + + if bytes.Compare(startKey, endKey) > 0 { + return errors.New("input endKey must greater or equal than startKey") + } + bc := backupContext{startKey: startKey, endKey: endKey, isRawKv: true, cf: cf} + return runBackup(command.Flags(), "Raw Backup", bc) + }, + } + command.Flags().StringP("format", "", "raw", "raw key format") + command.Flags().StringP("cf", "", "default", "backup raw kv cf") + command.Flags().StringP("start", "", "", "backup raw kv start key") + command.Flags().StringP("end", "", "", "backup raw kv end key") + return command +} diff --git a/cmd/restore.go b/cmd/restore.go index 4f66e47de..4da48f036 100644 --- a/cmd/restore.go +++ b/cmd/restore.go @@ -8,6 +8,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/backup" "github.com/pingcap/log" + restore_util "github.com/pingcap/tidb-tools/pkg/restore-util" "github.com/pingcap/tidb/session" "github.com/spf13/cobra" flag "github.com/spf13/pflag" @@ -53,6 +54,7 @@ func NewRestoreCommand() *cobra.Command { newFullRestoreCommand(), newDbRestoreCommand(), newTableRestoreCommand(), + newRawRestoreCommand(), ) command.PersistentFlags().Uint("concurrency", 128, @@ -90,6 +92,10 @@ func runRestore(flagSet *flag.FlagSet, cmdName, dbName, tableName string) error return errors.Trace(err) } + if client.IsRawKvMode() { + return errors.New("cannot do full restore from raw kv 
data") + } + files := make([]*backup.File, 0) tables := make([]*utils.Table, 0) @@ -214,6 +220,78 @@ func runRestore(flagSet *flag.FlagSet, cmdName, dbName, tableName string) error return nil } +func runRawRestore(flagSet *flag.FlagSet, startKey, endKey []byte, cf string) error { + ctx, cancel := context.WithCancel(GetDefaultContext()) + defer cancel() + + mgr, err := GetDefaultMgr() + if err != nil { + return err + } + defer mgr.Close() + + client, err := restore.NewRestoreClient( + ctx, mgr.GetPDClient(), mgr.GetTiKV()) + if err != nil { + return errors.Trace(err) + } + defer client.Close() + err = initRestoreClient(ctx, client, flagSet) + if err != nil { + return errors.Trace(err) + } + + if !client.IsRawKvMode() { + return errors.New("cannot do raw restore from transactional data") + } + + files, err := client.GetFilesInRawRange(startKey, endKey, cf) + if err != nil { + return errors.Trace(err) + } + + // Empty rewrite rules + rewriteRules := &restore_util.RewriteRules{} + + ranges, err := restore.ValidateFileRanges(files, rewriteRules) + if err != nil { + return errors.Trace(err) + } + + // Redirect to log if there is no log file to avoid unreadable output. + // TODO: How to show progress? + updateCh := utils.StartProgress( + ctx, + "Table Restore", + // Split/Scatter + Download/Ingest + int64(len(ranges)+len(files)), + !HasLogFile()) + + err = restore.SplitRanges(ctx, client, ranges, rewriteRules, updateCh) + if err != nil { + return errors.Trace(err) + } + + removedSchedulers, err := RestorePrepareWork(ctx, client, mgr) + if err != nil { + return errors.Trace(err) + } + + err = client.RestoreRaw(startKey, endKey, files, updateCh) + if err != nil { + return errors.Trace(err) + } + + err = RestorePostWork(ctx, client, mgr, removedSchedulers) + if err != nil { + return errors.Trace(err) + } + // Restore has finished. 
+ close(updateCh) + + return nil +} + func newFullRestoreCommand() *cobra.Command { command := &cobra.Command{ Use: "full", @@ -276,6 +354,43 @@ func newTableRestoreCommand() *cobra.Command { return command } +func newRawRestoreCommand() *cobra.Command { + command := &cobra.Command{ + Use: "raw", + Short: "restore a raw kv range", + RunE: func(cmd *cobra.Command, _ []string) error { + start, err := cmd.Flags().GetString("start") + if err != nil { + return err + } + startKey, err := utils.ParseKey(cmd.Flags(), start) + if err != nil { + return err + } + end, err := cmd.Flags().GetString("end") + if err != nil { + return err + } + endKey, err := utils.ParseKey(cmd.Flags(), end) + if err != nil { + return err + } + + cf, err := cmd.Flags().GetString("cf") + if err != nil { + return errors.Trace(err) + } + return runRawRestore(cmd.Flags(), startKey, endKey, cf) + }, + } + + command.Flags().StringP("format", "", "raw", "format of raw keys in arguments") + command.Flags().StringP("start", "", "", "restore raw kv start key") + command.Flags().StringP("end", "", "", "restore raw kv end key") + command.Flags().StringP("cf", "", "default", "the cf to restore raw keys") + return command +} + func initRestoreClient(ctx context.Context, client *restore.Client, flagSet *flag.FlagSet) error { u, err := storage.ParseBackendFromFlags(flagSet, FlagStorage) if err != nil { diff --git a/go.mod b/go.mod index 9951c2922..aed6974c3 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/onsi/gomega v1.7.1 // indirect github.com/pingcap/check v0.0.0-20191107115940-caf2b9e6ccf4 github.com/pingcap/errors v0.11.4 - github.com/pingcap/kvproto v0.0.0-20191212110315-d6a9d626988c + github.com/pingcap/kvproto v0.0.0-20191217072959-393e6c0fd4b7 github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9 github.com/pingcap/parser v0.0.0-20191210060830-bdf23a7ade01 github.com/pingcap/pd v1.1.0-beta.0.20191212045800-234784c7a9c5 diff --git a/go.sum b/go.sum index 085e00355..461c4ef12 
100644 --- a/go.sum +++ b/go.sum @@ -268,8 +268,8 @@ github.com/pingcap/kvproto v0.0.0-20191030021250-51b332bcb20b/go.mod h1:WWLmULLO github.com/pingcap/kvproto v0.0.0-20191121022655-4c654046831d/go.mod h1:WWLmULLO7l8IOcQG+t+ItJ3fEcrL5FxF0Wu+HrMy26w= github.com/pingcap/kvproto v0.0.0-20191202044712-32be31591b03 h1:IyJl+qesVPf3UfFFmKtX69y1K5KC8uXlot3U0QgH7V4= github.com/pingcap/kvproto v0.0.0-20191202044712-32be31591b03/go.mod h1:WWLmULLO7l8IOcQG+t+ItJ3fEcrL5FxF0Wu+HrMy26w= -github.com/pingcap/kvproto v0.0.0-20191212110315-d6a9d626988c h1:CwVCq7XA/NvTQ6X9ZAhZlvcEvseUsHiPFQf2mL3LVl4= -github.com/pingcap/kvproto v0.0.0-20191212110315-d6a9d626988c/go.mod h1:WWLmULLO7l8IOcQG+t+ItJ3fEcrL5FxF0Wu+HrMy26w= +github.com/pingcap/kvproto v0.0.0-20191217072959-393e6c0fd4b7 h1:thLL2vFObG8vxBCkAmfAbLVBPfXUkBSXaVxppStCrL0= +github.com/pingcap/kvproto v0.0.0-20191217072959-393e6c0fd4b7/go.mod h1:WWLmULLO7l8IOcQG+t+ItJ3fEcrL5FxF0Wu+HrMy26w= github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd h1:hWDol43WY5PGhsh3+8794bFHY1bPrmu6bTalpssCrGg= github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9 h1:AJD9pZYm72vMgPcQDww9rkZ1DnWfl0pXV3BOWlkYIjA= diff --git a/pkg/backup/client.go b/pkg/backup/client.go index 5cba2d9bf..ed5730620 100644 --- a/pkg/backup/client.go +++ b/pkg/backup/client.go @@ -285,6 +285,8 @@ func (bc *Client) BackupRanges( rateLimit uint64, concurrency uint32, updateCh chan<- struct{}, + isRawKv bool, + cf string, ) error { start := time.Now() defer func() { @@ -298,7 +300,7 @@ func (bc *Client) BackupRanges( go func() { for _, r := range ranges { err := bc.backupRange( - ctx, r.StartKey, r.EndKey, lastBackupTS, backupTS, rateLimit, concurrency, updateCh) + ctx, r.StartKey, r.EndKey, lastBackupTS, backupTS, rateLimit, concurrency, updateCh, isRawKv, cf) if err != nil { errCh <- err return @@ -346,6 +348,8 @@ func (bc *Client) backupRange( rateLimit 
uint64, concurrency uint32, updateCh chan<- struct{}, + isRawKv bool, + cf string, ) (err error) { start := time.Now() defer func() { @@ -381,6 +385,8 @@ func (bc *Client) backupRange( StorageBackend: bc.backend, RateLimit: rateLimit, Concurrency: concurrency, + IsRawKv: isRawKv, + Cf: cf, } push := newPushDown(ctx, bc.mgr, len(allStores)) @@ -402,9 +408,19 @@ func (bc *Client) backupRange( bc.backupMeta.StartVersion = lastBackupTS bc.backupMeta.EndVersion = backupTS - log.Info("backup time range", - zap.Reflect("StartVersion", lastBackupTS), - zap.Reflect("EndVersion", backupTS)) + bc.backupMeta.IsRawKv = isRawKv + if bc.backupMeta.IsRawKv { + bc.backupMeta.RawRanges = append(bc.backupMeta.RawRanges, + &backup.RawRange{StartKey: startKey, EndKey: endKey, Cf: cf}) + log.Info("backup raw ranges", + zap.ByteString("startKey", startKey), + zap.ByteString("endKey", endKey), + zap.String("cf", cf)) + } else { + log.Info("backup time range", + zap.Reflect("StartVersion", backupTS), + zap.Reflect("EndVersion", backupTS)) + } results.tree.Ascend(func(i btree.Item) bool { r := i.(*Range) diff --git a/pkg/restore/client.go b/pkg/restore/client.go index 3030ba857..a0edfae33 100644 --- a/pkg/restore/client.go +++ b/pkg/restore/client.go @@ -1,7 +1,9 @@ package restore import ( + "bytes" "context" + "encoding/hex" "fmt" "math" "sync" @@ -100,19 +102,78 @@ func (rc *Client) Close() { // InitBackupMeta loads schemas from BackupMeta to initialize RestoreClient func (rc *Client) InitBackupMeta(backupMeta *backup.BackupMeta, backend *backup.StorageBackend) error { - databases, err := utils.LoadBackupTables(backupMeta) - if err != nil { - return errors.Trace(err) + if !backupMeta.IsRawKv { + databases, err := utils.LoadBackupTables(backupMeta) + if err != nil { + return errors.Trace(err) + } + rc.databases = databases } - rc.databases = databases rc.backupMeta = backupMeta metaClient := NewSplitClient(rc.pdClient) importClient := NewImportClient(metaClient) - rc.fileImporter = 
NewFileImporter(rc.ctx, metaClient, importClient, backend, rc.rateLimit) + rc.fileImporter = NewFileImporter(rc.ctx, metaClient, importClient, backend, backupMeta.IsRawKv, rc.rateLimit) return nil } +// IsRawKvMode checks whether the backup data is in raw kv format, in which case transactional recover is forbidden. +func (rc *Client) IsRawKvMode() bool { + return rc.backupMeta.IsRawKv +} + +// GetFilesInRawRange gets all files that are in the given range or intersects with the given range. +func (rc *Client) GetFilesInRawRange(startKey []byte, endKey []byte, cf string) ([]*backup.File, error) { + if !rc.IsRawKvMode() { + return nil, errors.New("the backup data is not in raw kv mode") + } + + for _, rawRange := range rc.backupMeta.RawRanges { + // First check whether the given range is backup-ed. If not, we cannot perform the restore. + if rawRange.Cf != cf { + continue + } + + if (len(rawRange.EndKey) > 0 && bytes.Compare(startKey, rawRange.EndKey) >= 0) || + (len(endKey) > 0 && bytes.Compare(rawRange.StartKey, endKey) >= 0) { + // The restoring range is totally out of the current range. Skip it. + continue + } + + if bytes.Compare(startKey, rawRange.StartKey) < 0 || + utils.CompareEndKey(endKey, rawRange.EndKey) > 0 { + // Only partial of the restoring range is in the current backup-ed range. So the given range can't be fully + // restored. + return nil, errors.New("no backup data in the range") + } + + // We have found the range that contains the given range. Find all necessary files. + files := make([]*backup.File, 0) + + for _, file := range rc.backupMeta.Files { + if file.Cf != cf { + continue + } + + if len(file.EndKey) > 0 && bytes.Compare(file.EndKey, startKey) < 0 { + // The file is before the range to be restored. + continue + } + if len(endKey) > 0 && bytes.Compare(endKey, file.StartKey) >= 0 { + // The file is after the range to be restored. + // The specified endKey is exclusive, so when it equals to a file's startKey, the file is still skipped. 
+ continue + } + + files = append(files, file) + } + + return files, nil + } + + return nil, errors.New("no backup data in the range") +} + // SetConcurrency sets the concurrency of dbs tables files func (rc *Client) SetConcurrency(c uint) { rc.workerPool = utils.NewWorkerPool(c, "file") @@ -370,6 +431,63 @@ func (rc *Client) RestoreAll( return nil } +// RestoreRaw tries to restore raw keys in the specified range. +func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.File, updateCh chan<- struct{}) error { + start := time.Now() + defer func() { + elapsed := time.Since(start) + log.Info("Restore Raw", + zap.String("startKey", hex.EncodeToString(startKey)), + zap.String("endKey", hex.EncodeToString(endKey)), + zap.Duration("take", elapsed)) + }() + errCh := make(chan error, len(rc.databases)) + wg := new(sync.WaitGroup) + defer close(errCh) + + err := rc.fileImporter.SetRawRange(startKey, endKey) + if err != nil { + + return errors.Trace(err) + } + + emptyRules := &restore_util.RewriteRules{} + for _, file := range files { + wg.Add(1) + fileReplica := file + rc.workerPool.Apply( + func() { + defer wg.Done() + select { + case <-rc.ctx.Done(): + errCh <- nil + case errCh <- rc.fileImporter.Import(fileReplica, emptyRules): + updateCh <- struct{}{} + } + }) + } + for range files { + err := <-errCh + if err != nil { + rc.cancel() + wg.Wait() + log.Error( + "restore raw range failed", + zap.String("startKey", hex.EncodeToString(startKey)), + zap.String("endKey", hex.EncodeToString(endKey)), + zap.Error(err), + ) + return err + } + } + log.Info( + "finish to restore raw range", + zap.String("startKey", hex.EncodeToString(startKey)), + zap.String("endKey", hex.EncodeToString(endKey)), + ) + return nil +} + //SwitchToImportMode switch tikv cluster to import mode func (rc *Client) SwitchToImportMode(ctx context.Context) error { return rc.switchTiKVMode(ctx, import_sstpb.SwitchMode_Import) diff --git a/pkg/restore/import.go b/pkg/restore/import.go index 
77273ebab..b22048e75 100644 --- a/pkg/restore/import.go +++ b/pkg/restore/import.go @@ -1,6 +1,7 @@ package restore import ( + "bytes" "context" "sync" "time" @@ -137,6 +138,10 @@ type FileImporter struct { backend *backup.StorageBackend rateLimit uint64 + isRawKvMode bool + rawStartKey []byte + rawEndKey []byte + ctx context.Context cancel context.CancelFunc } @@ -147,6 +152,7 @@ func NewFileImporter( metaClient SplitClient, importClient ImporterClient, backend *backup.StorageBackend, + isRawKvMode bool, rateLimit uint64, ) FileImporter { ctx, cancel := context.WithCancel(ctx) @@ -156,10 +162,21 @@ func NewFileImporter( ctx: ctx, cancel: cancel, importClient: importClient, + isRawKvMode: isRawKvMode, rateLimit: rateLimit, } } +// SetRawRange sets the range to be restored in raw kv mode. +func (importer *FileImporter) SetRawRange(startKey, endKey []byte) error { + if !importer.isRawKvMode { + return errors.New("file importer is not in raw kv mode") + } + importer.rawStartKey = startKey + importer.rawEndKey = endKey + return nil +} + // Import tries to import a file. // All rules must contain encoded keys. func (importer *FileImporter) Import(file *backup.File, rewriteRules *RewriteRules) error { @@ -277,13 +294,27 @@ func (importer *FileImporter) downloadSST( ) return nil, true, errRewriteRuleNotFound } - rule := import_sstpb.RewriteRule{ + rule := import_sstpb.RewriteRule{ OldKeyPrefix: encodeKeyPrefix(regionRule.GetOldKeyPrefix()), NewKeyPrefix: encodeKeyPrefix(regionRule.GetNewKeyPrefix()), } sstMeta := getSSTMetaFromFile(id, file, regionInfo.Region, &rule) - sstMeta.RegionId = regionInfo.Region.GetId() + sstMeta.RegionId = regionInfo.Region.GetId() sstMeta.RegionEpoch = regionInfo.Region.GetRegionEpoch() + // For raw kv mode, cut the SST file's range to fit in the restoring range. 
+ if importer.isRawKvMode { + if bytes.Compare(importer.rawStartKey, sstMeta.Range.GetStart()) > 0 { + sstMeta.Range.Start = importer.rawStartKey + } + // TODO: importer.RawEndKey is exclusive but sstMeta.Range.End is inclusive. How to exclude importer.RawEndKey? + if len(importer.rawEndKey) > 0 && bytes.Compare(importer.rawEndKey, sstMeta.Range.GetEnd()) < 0 { + sstMeta.Range.End = importer.rawEndKey + } + if bytes.Compare(sstMeta.Range.GetStart(), sstMeta.Range.GetEnd()) > 0 { + return &sstMeta, true, nil + } + } + req := &import_sstpb.DownloadRequest{ Sst: sstMeta, StorageBackend: importer.backend, diff --git a/pkg/utils/key.go b/pkg/utils/key.go new file mode 100644 index 000000000..59a35f743 --- /dev/null +++ b/pkg/utils/key.go @@ -0,0 +1,70 @@ +package utils + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "strings" + + "github.com/pingcap/errors" + "github.com/spf13/pflag" +) + +// ParseKey parse key by given format +func ParseKey(flags *pflag.FlagSet, key string) ([]byte, error) { + switch flags.Lookup("format").Value.String() { + case "raw": + return []byte(key), nil + case "escaped": + return unescapedKey(key) + case "hex": + key, err := hex.DecodeString(key) + if err != nil { + return nil, errors.WithStack(err) + } + return key, nil + } + return nil, errors.New("unknown format") +} + +func unescapedKey(text string) ([]byte, error) { + var buf []byte + r := bytes.NewBuffer([]byte(text)) + for { + c, err := r.ReadByte() + if err != nil { + if err != io.EOF { + return nil, errors.WithStack(err) + } + break + } + if c != '\\' { + buf = append(buf, c) + continue + } + n := r.Next(1) + if len(n) == 0 { + return nil, io.EOF + } + // See: https://golang.org/ref/spec#Rune_literals + if idx := strings.IndexByte(`abfnrtv\'"`, n[0]); idx != -1 { + buf = append(buf, []byte("\a\b\f\n\r\t\v\\'\"")[idx]) + continue + } + + switch n[0] { + case 'x': + fmt.Sscanf(string(r.Next(2)), "%02x", &c) + buf = append(buf, c) + default: + n = append(n, r.Next(2)...) 
+ _, err := fmt.Sscanf(string(n), "%03o", &c) + if err != nil { + return nil, errors.WithStack(err) + } + buf = append(buf, c) + } + } + return buf, nil +} diff --git a/pkg/utils/key_test.go b/pkg/utils/key_test.go new file mode 100644 index 000000000..c983d34e5 --- /dev/null +++ b/pkg/utils/key_test.go @@ -0,0 +1,41 @@ +package utils + +import ( + "encoding/hex" + + . "github.com/pingcap/check" + "github.com/spf13/pflag" +) + +type testKeySuite struct{} + +var _ = Suite(&testKeySuite{}) + +func (r *testKeySuite) TestParseKey(c *C) { + flagSet := &pflag.FlagSet{} + flagSet.String("format", "raw", "") + rawKey := "1234" + parsedKey, err := ParseKey(flagSet, rawKey) + c.Assert(err, IsNil) + c.Assert(parsedKey, BytesEquals, []byte(rawKey)) + + flagSet = &pflag.FlagSet{} + flagSet.String("format", "escaped", "") + escapedKey := "\\a\\x1" + parsedKey, err = ParseKey(flagSet, escapedKey) + c.Assert(err, IsNil) + c.Assert(parsedKey, BytesEquals, []byte("\a\x01")) + + flagSet = &pflag.FlagSet{} + flagSet.String("format", "hex", "") + hexKey := hex.EncodeToString([]byte("1234")) + parsedKey, err = ParseKey(flagSet, hexKey) + c.Assert(err, IsNil) + c.Assert(parsedKey, BytesEquals, []byte("1234")) + + flagSet = &pflag.FlagSet{} + flagSet.String("format", "notSupport", "") + _, err = ParseKey(flagSet, rawKey) + c.Assert(err, ErrorMatches, "*unknown format*") + +} diff --git a/pkg/utils/keys.go b/pkg/utils/keys.go new file mode 100644 index 000000000..f03a21d25 --- /dev/null +++ b/pkg/utils/keys.go @@ -0,0 +1,21 @@ +package utils + +import "bytes" + +// CompareEndKey compared two keys that BOTH represent the EXCLUSIVE ending of some range. An empty end key is the very +// end, so an empty key is greater than any other keys. +// Please note that this function is not applicable if any one argument is not an EXCLUSIVE ending of a range. 
func CompareEndKey(a, b []byte) int {
	// An empty end key stands for "the very end" of the key space, so it sorts
	// after every non-empty key; two empty end keys are equal.
	aIsEnd, bIsEnd := len(a) == 0, len(b) == 0

	switch {
	case aIsEnd && bIsEnd:
		return 0
	case aIsEnd:
		return 1
	case bIsEnd:
		return -1
	default:
		// Both keys are concrete, so plain lexicographic byte order applies.
		return bytes.Compare(a, b)
	}
}