diff --git a/models/channel_watch.go b/models/channel_watch.go index b8ead290..9417586c 100644 --- a/models/channel_watch.go +++ b/models/channel_watch.go @@ -3,6 +3,7 @@ package models import ( "github.com/samber/lo" + datapbv2 "github.com/milvus-io/birdwatcher/proto/v2.2/datapb" "github.com/milvus-io/birdwatcher/proto/v2.2/schemapb" ) @@ -12,9 +13,15 @@ type ChannelWatch struct { State ChannelWatchState TimeoutTs int64 + // 2.4 only + Progress int32 + OpID int64 + // key key string - Schema CollectionSchema + Schema *CollectionSchema + + VchanV2Pb *datapbv2.VchannelInfo } func (c *ChannelWatch) Key() string { @@ -57,29 +64,36 @@ func GetChannelWatchInfo[ChannelWatchBase interface { } func GetChannelWatchInfoV2[ChannelWatchBase interface { - GetVchan() vchan + GetVchan() *datapbv2.VchannelInfo GetStartTs() int64 GetState() watchState GetTimeoutTs() int64 GetSchema() *schemapb.CollectionSchema -}, watchState ~int32, vchan interface { - vchannelInfoBase - GetSeekPosition() pos -}, pos msgPosBase](info ChannelWatchBase, key string) *ChannelWatch { - schema := newSchemaFromBase(info.GetSchema()) - schema.Fields = lo.Map(info.GetSchema().GetFields(), func(fieldSchema *schemapb.FieldSchema, _ int) FieldSchema { - fs := NewFieldSchemaFromBase[*schemapb.FieldSchema, schemapb.DataType](fieldSchema) - fs.Properties = GetMapFromKVPairs(fieldSchema.GetTypeParams()) - return fs - }) + GetProgress() int32 + GetOpID() int64 +}, watchState ~int32, pos msgPosBase](info ChannelWatchBase, key string) *ChannelWatch { + var schema *CollectionSchema + if info.GetSchema() != nil { + m := newSchemaFromBase(info.GetSchema()) + schema = &m + schema.Fields = lo.Map(info.GetSchema().GetFields(), func(fieldSchema *schemapb.FieldSchema, _ int) FieldSchema { + fs := NewFieldSchemaFromBase[*schemapb.FieldSchema, schemapb.DataType](fieldSchema) + fs.Properties = GetMapFromKVPairs(fieldSchema.GetTypeParams()) + return fs + }) + } return &ChannelWatch{ - Vchan: getVChannelInfo[vchan, pos](info.GetVchan()), + Vchan: getVChannelInfo(info.GetVchan()), StartTs: info.GetStartTs(), State: ChannelWatchState(info.GetState()), TimeoutTs: info.GetTimeoutTs(), key: key, Schema: schema, + Progress: info.GetProgress(), + OpID: info.GetOpID(), + + VchanV2Pb: info.GetVchan(), } } diff --git a/models/collection.go b/models/collection.go index 24c34214..9f6f9ac8 100644 --- a/models/collection.go +++ b/models/collection.go @@ -28,6 +28,8 @@ type Collection struct { Properties map[string]string DBID int64 + CollectionPBv2 *schemapbv2.CollectionSchema + // etcd collection key key string @@ -114,6 +116,7 @@ func NewCollectionFromV2_2(info *etcdpbv2.CollectionInfo, key string, fields []* schema := info.GetSchema() schema.Fields = fields c.Schema = newSchemaFromBase(schema) + c.CollectionPBv2 = schema c.Schema.Fields = lo.Map(fields, func(fieldSchema *schemapbv2.FieldSchema, _ int) FieldSchema { fs := NewFieldSchemaFromBase[*schemapbv2.FieldSchema, schemapbv2.DataType](fieldSchema) diff --git a/states/backup_mock_connect.go b/states/backup_mock_connect.go index d711794a..b2bb57b6 100644 --- a/states/backup_mock_connect.go +++ b/states/backup_mock_connect.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/birdwatcher/models" "github.com/milvus-io/birdwatcher/states/etcd" "github.com/milvus-io/birdwatcher/states/etcd/remove" + "github.com/milvus-io/birdwatcher/states/etcd/repair" "github.com/milvus-io/birdwatcher/states/etcd/show" ) @@ -32,6 +33,7 @@ type embedEtcdMockState struct { cmdState *show.ComponentShow *remove.ComponentRemove + *repair.ComponentRepair client *clientv3.Client server *embed.Etcd instanceName string @@ -90,6 +92,7 @@ func (s *embedEtcdMockState) SetInstance(instanceName string) { rootPath := path.Join(instanceName, metaPath) s.ComponentShow = show.NewComponent(s.client, s.config, rootPath) s.ComponentRemove = remove.NewComponent(s.client, s.config, rootPath) + s.ComponentRepair = repair.NewComponent(s.client, s.config, rootPath) s.SetupCommands() } diff --git a/states/etcd/common/channel.go b/states/etcd/common/channel.go index efd812fe..b3f29ed3 100644 --- a/states/etcd/common/channel.go +++ b/states/etcd/common/channel.go @@ -58,7 +58,8 @@ func ListChannelWatch(ctx context.Context, cli clientv3.KV, basePath string, ver return nil, err } result = lo.Map(infos, func(info datapb.ChannelWatchInfo, idx int) *models.ChannelWatch { - return models.GetChannelWatchInfo[*datapb.ChannelWatchInfo, datapb.ChannelWatchState, *datapb.VchannelInfo, *internalpb.MsgPosition](&info, paths[idx]) + result := models.GetChannelWatchInfo[*datapb.ChannelWatchInfo, datapb.ChannelWatchState, *datapb.VchannelInfo, *internalpb.MsgPosition](&info, paths[idx]) + return result }) case models.GTEVersion2_2: infos, paths, err := ListProtoObjects[datapbv2.ChannelWatchInfo](ctx, cli, prefix) @@ -66,7 +67,7 @@ func ListChannelWatch(ctx context.Context, cli clientv3.KV, basePath string, ver return nil, err } result = lo.Map(infos, func(info datapbv2.ChannelWatchInfo, idx int) *models.ChannelWatch { - return models.GetChannelWatchInfoV2[*datapbv2.ChannelWatchInfo, datapbv2.ChannelWatchState, *datapbv2.VchannelInfo, *msgpbv2.MsgPosition](&info, paths[idx]) + return models.GetChannelWatchInfoV2[*datapbv2.ChannelWatchInfo, datapbv2.ChannelWatchState, *msgpbv2.MsgPosition](&info, paths[idx]) }) default: return nil, errors.New("version not supported") diff --git a/states/etcd/common/channel_watch.go b/states/etcd/common/channel_watch.go new file mode 100644 index 00000000..d2a34a84 --- /dev/null +++ b/states/etcd/common/channel_watch.go @@ -0,0 +1,30 @@ +package common + +import ( + "context" + + "github.com/golang/protobuf/proto" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/milvus-io/birdwatcher/models" + datapbv2 "github.com/milvus-io/birdwatcher/proto/v2.2/datapb" + schemapbv2 "github.com/milvus-io/birdwatcher/proto/v2.2/schemapb" +) + +func WriteChannelWatchInfo(ctx context.Context, cli clientv3.KV, basePath string, info *models.ChannelWatch, schema *schemapbv2.CollectionSchema) error { + pb := &datapbv2.ChannelWatchInfo{ + Vchan: info.VchanV2Pb, + StartTs: info.StartTs, + State: datapbv2.ChannelWatchState(info.State), + TimeoutTs: info.TimeoutTs, + Schema: schema, // use passed schema + Progress: info.Progress, + OpID: info.OpID, + } + bs, err := proto.Marshal(pb) + if err != nil { + return err + } + _, err = cli.Put(ctx, info.Key(), string(bs)) + return err +} diff --git a/states/etcd/common/collection.go b/states/etcd/common/collection.go index 0562586e..469fe839 100644 --- a/states/etcd/common/collection.go +++ b/states/etcd/common/collection.go @@ -121,18 +121,18 @@ func ListCollectionsVersion(ctx context.Context, cli clientv3.KV, basePath strin // GetCollectionByIDVersion retruns collection info from etcd with provided version & id. func GetCollectionByIDVersion(ctx context.Context, cli clientv3.KV, basePath string, version string, collID int64) (*models.Collection, error) { - var result []*mvccpb.KeyValue - // meta before database + var legacy []*mvccpb.KeyValue prefix := path.Join(basePath, CollectionMetaPrefix, strconv.FormatInt(collID, 10)) resp, err := cli.Get(ctx, prefix) if err != nil { fmt.Println("get error", err.Error()) return nil, err } - result = append(result, resp.Kvs...) + legacy = append(legacy, resp.Kvs...) // with database, dbID unknown here + var result []*mvccpb.KeyValue prefix = path.Join(basePath, DBCollectionMetaPrefix) resp, _ = cli.Get(ctx, prefix, clientv3.WithPrefix()) suffix := strconv.FormatInt(collID, 10) @@ -142,11 +142,16 @@ func GetCollectionByIDVersion(ctx context.Context, cli clientv3.KV, basePath str } } - if len(result) != 1 { + if len(legacy)+len(result) == 0 { return nil, fmt.Errorf("collection %d not found in etcd %w", collID, ErrCollectionNotFound) } - kv := result[0] + var kv *mvccpb.KeyValue + if len(result) > 0 { + kv = result[0] + } else { + kv = legacy[0] + } if bytes.Equal(kv.Value, CollectionTombstone) { return nil, fmt.Errorf("%w, collection id: %d", ErrCollectionDropped, collID) diff --git a/states/etcd/common/list.go b/states/etcd/common/list.go index 00996120..bdc46750 100644 --- a/states/etcd/common/list.go +++ b/states/etcd/common/list.go @@ -1,6 +1,7 @@ package common import ( + "bytes" "context" "fmt" @@ -26,7 +27,11 @@ LOOP: info := P(&elem) err = proto.Unmarshal(kv.Value, info) if err != nil { - fmt.Println(err.Error()) + if bytes.Equal(kv.Value, []byte{0xE2, 0x9B, 0xBC}) { + fmt.Printf("Tombstone found, key: %s\n", string(kv.Key)) + continue + } + fmt.Printf("failed to unmarshal key=%s, err: %s\n", string(kv.Key), err.Error()) continue } diff --git a/states/etcd/repair/channel_watched.go b/states/etcd/repair/channel_watched.go new file mode 100644 index 00000000..ffc2fa04 --- /dev/null +++ b/states/etcd/repair/channel_watched.go @@ -0,0 +1,103 @@ +package repair + +import ( + "context" + "fmt" + "sort" + "strings" + + "github.com/pkg/errors" + + "github.com/milvus-io/birdwatcher/framework" + "github.com/milvus-io/birdwatcher/models" + "github.com/milvus-io/birdwatcher/states/etcd/common" + etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version" +) + +type ChannelWatchedParam struct { + framework.ParamBase `use:"repair channel-watch"` + CollectionID int64 `name:"collection" default:"0" desc:"collection id to repair"` + ChannelName string `name:"vchannel" default:"" desc:"channel name to repair"` + Run bool `name:"run" default:"false" desc:"whether to remove legacy collection meta, default set to \"false\" to dry run"` +} + +func (c *ComponentRepair) RepairChannelWatchedCommand(ctx context.Context, p *ChannelWatchedParam) error { + infos, err := common.ListChannelWatch(ctx, c.client, c.basePath, etcdversion.GetVersion(), func(channel *models.ChannelWatch) bool { + return (p.CollectionID == 0 || channel.Vchan.CollectionID == p.CollectionID) && + (p.ChannelName == "" || channel.Vchan.ChannelName == p.ChannelName) + }) + if err != nil { + return errors.Wrap(err, "failed to list channel watch info") + } + + var targets []*models.ChannelWatch + + for _, info := range infos { + if info.Schema == nil { + targets = append(targets, info) + } + } + + if len(targets) == 0 { + fmt.Println("No empty schema watch info found") + return nil + } + + for _, info := range targets { + fmt.Println("=================================================================") + fmt.Printf("Watch info with empty schema found, channel name = %s, key = %s", info.Vchan.ChannelName, info.Key()) + + collection, err := common.GetCollectionByIDVersion(ctx, c.client, c.basePath, etcdversion.GetVersion(), info.Vchan.CollectionID) + if err != nil { + fmt.Println("failed to get collection schema: ", err.Error()) + } + sb := &strings.Builder{} + info.Schema = &collection.Schema + printSchema(sb, info) + fmt.Println("Collection schema found, about to set schema as:") + fmt.Println(sb.String()) + if p.Run { + err := common.WriteChannelWatchInfo(ctx, c.client, c.basePath, info, collection.CollectionPBv2) + if err != nil { + fmt.Println("failed to write modified channel watch info, err: ", err.Error()) + continue + } + fmt.Println("Modified channel watch info written!") + } + } + + return nil +} + +func printSchema(sb *strings.Builder, info *models.ChannelWatch) { + fmt.Fprintf(sb, "Fields:\n") + fields := info.Schema.Fields + sort.Slice(fields, func(i, j int) bool { + return fields[i].FieldID < fields[j].FieldID + }) + for _, field := range fields { + fmt.Fprintf(sb, " - Field ID: %d \t Field Name: %s \t Field Type: %s\n", field.FieldID, field.Name, field.DataType.String()) + if field.IsPrimaryKey { + fmt.Fprintf(sb, "\t - Primary Key: %t, AutoID: %t\n", field.IsPrimaryKey, field.AutoID) + } + if field.IsDynamic { + fmt.Fprintf(sb, "\t - Dynamic Field\n") + } + if field.IsPartitionKey { + fmt.Fprintf(sb, "\t - Partition Key\n") + } + if field.IsClusteringKey { + fmt.Fprintf(sb, "\t - Clustering Key\n") + } + // print element type if field is array + if field.DataType == models.DataTypeArray { + fmt.Fprintf(sb, "\t - Element Type: %s\n", field.ElementType.String()) + } + // type params + for key, value := range field.Properties { + fmt.Fprintf(sb, "\t - Type Param %s: %s\n", key, value) + } + } + + fmt.Fprintf(sb, "Enable Dynamic Schema: %t\n", info.Schema.EnableDynamicSchema) +} diff --git a/states/etcd/restore/collection.go b/states/etcd/restore/collection.go new file mode 100644 index 00000000..b8a7f1ed --- /dev/null +++ b/states/etcd/restore/collection.go @@ -0,0 +1,86 @@ +package restore + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/milvus-io/birdwatcher/models" + commonpbv2 "github.com/milvus-io/birdwatcher/proto/v2.2/commonpb" + datapbv2 "github.com/milvus-io/birdwatcher/proto/v2.2/datapb" + etcdpbv2 "github.com/milvus-io/birdwatcher/proto/v2.2/etcdpb" + "github.com/milvus-io/birdwatcher/states/etcd/common" + etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version" +) + +// CollectionCommand returns sub command for restore command. +// restore collection [options...] +func CollectionCommand(cli clientv3.KV, basePath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "collection", + Short: "restore dropping/dropped collection meta", + Run: func(cmd *cobra.Command, args []string) { + collectionID, err := cmd.Flags().GetInt64("id") + if err != nil { + fmt.Println(err.Error()) + return + } + + if collectionID == 0 { + fmt.Println("collection id not provided") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var collection *models.Collection + // TODO check history as well for dropped collection + collection, err = common.GetCollectionByIDVersion(ctx, cli, basePath, etcdversion.GetVersion(), collectionID) + if err != nil { + fmt.Println("faile to get collection by id", err.Error()) + return + } + + updateCollState := func(collection *etcdpbv2.CollectionInfo) { + collection.State = etcdpbv2.CollectionState_CollectionCreated + } + + if collection.DBID > 0 { + // err = common.UpdateCollectionWithDB(ctx, cli, basePath, collectionID, collection.DBID, updateCollState) + } else { + err = common.UpdateCollection(ctx, cli, basePath, collectionID, updateCollState, false) + } + + if err != nil { + fmt.Println(err.Error()) + return + } + err = common.RemoveCollectionHistory(ctx, cli, basePath, etcdversion.GetVersion(), collectionID) + if err != nil { + fmt.Println(err.Error()) + return + } + + err = common.UpdateSegments(ctx, cli, basePath, collectionID, func(segment *datapbv2.SegmentInfo) { + segment.State = commonpbv2.SegmentState_Flushed + }) + if err != nil { + fmt.Println(err.Error()) + return + } + + for _, channel := range collection.Channels { + err = common.SetChannelWatch(ctx, cli, basePath, channel.VirtualName, collection) + if err != nil { + fmt.Println(err.Error()) + } + } + }, + } + + cmd.Flags().Int64("id", 0, "collection id to restore") + return cmd +} diff --git a/states/etcd/show/channel_watched.go b/states/etcd/show/channel_watched.go index 021ae6f6..a98965c8 100644 --- a/states/etcd/show/channel_watched.go +++ b/states/etcd/show/channel_watched.go @@ -56,8 +56,7 @@ func (rs *ChannelsWatched) printChannelWatchInfo(sb *strings.Builder, info *mode fmt.Fprintln(sb, "=============================") fmt.Fprintf(sb, "key: %s\n", info.Key()) fmt.Fprintf(sb, "Channel Name:%s \t WatchState: %s\n", info.Vchan.ChannelName, info.State.String()) - // t, _ := ParseTS(uint64(info.GetStartTs())) - // to, _ := ParseTS(uint64(info.GetTimeoutTs())) + t := time.Unix(info.StartTs, 0) to := time.Unix(0, info.TimeoutTs) fmt.Fprintf(sb, "Channel Watch start from: %s, timeout at: %s\n", t.Format(tsPrintFormat), to.Format(tsPrintFormat)) @@ -73,6 +72,10 @@ func (rs *ChannelsWatched) printChannelWatchInfo(sb *strings.Builder, info *mode fmt.Fprintf(sb, "Dropped segments: %v\n", info.Vchan.DroppedSegmentIds) fmt.Fprintf(sb, "Fields:\n") + if info.Schema == nil { + fmt.Fprintf(sb, "### Collection schema is empty!!! ###\n") + return + } fields := info.Schema.Fields sort.Slice(fields, func(i, j int) bool { return fields[i].FieldID < fields[j].FieldID