diff --git a/pkg/ccl/backupccl/restore_planning.go b/pkg/ccl/backupccl/restore_planning.go index d1bf9496265f..052d75e85f5e 100644 --- a/pkg/ccl/backupccl/restore_planning.go +++ b/pkg/ccl/backupccl/restore_planning.go @@ -116,14 +116,23 @@ func rewriteTypesInExpr(expr string, rewrites DescRewriteMap) (string, error) { if err != nil { return "", err } + ctx := tree.NewFmtCtx(tree.FmtSerializable) ctx.SetIndexedTypeFormat(func(ctx *tree.FmtCtx, ref *tree.OIDTypeReference) { newRef := ref - if rw, ok := rewrites[typedesc.UserDefinedTypeOIDToID(ref.OID)]; ok { + var id descpb.ID + id, err = typedesc.UserDefinedTypeOIDToID(ref.OID) + if err != nil { + return + } + if rw, ok := rewrites[id]; ok { newRef = &tree.OIDTypeReference{OID: typedesc.TypeIDToOID(rw.ID)} } ctx.WriteString(newRef.SQLString()) }) ctx.FormatNode(parsed) + if err != nil { + return "", err + } return ctx.CloseAndGetString(), nil } @@ -348,11 +357,15 @@ func allocateDescriptorRewrites( // Ensure that all referenced types are present. if col.Type.UserDefined() { // TODO (rohany): This can be turned into an option later. - if _, ok := typesByID[typedesc.GetTypeDescID(col.Type)]; !ok { + id, err := typedesc.GetUserDefinedTypeDescID(col.Type) + if err != nil { + return nil, err + } + if _, ok := typesByID[id]; !ok { return nil, errors.Errorf( "cannot restore table %q without referenced type %d", table.Name, - typedesc.GetTypeDescID(col.Type), + id, ) } } @@ -1025,25 +1038,37 @@ func rewriteDatabaseDescs(databases []*dbdesc.Mutable, descriptorRewrites DescRe // rewriteIDsInTypesT rewrites all ID's in the input types.T using the input // ID rewrite mapping. -func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) { +func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) error { if !typ.UserDefined() { - return + return nil + } + tid, err := typedesc.GetUserDefinedTypeDescID(typ) + if err != nil { + return err } // Collect potential new OID values. var newOID, newArrayOID oid.Oid - if rw, ok := descriptorRewrites[typedesc.GetTypeDescID(typ)]; ok { + if rw, ok := descriptorRewrites[tid]; ok { newOID = typedesc.TypeIDToOID(rw.ID) } if typ.Family() != types.ArrayFamily { - if rw, ok := descriptorRewrites[typedesc.GetArrayTypeDescID(typ)]; ok { + tid, err = typedesc.GetUserDefinedArrayTypeDescID(typ) + if err != nil { + return err + } + if rw, ok := descriptorRewrites[tid]; ok { newArrayOID = typedesc.TypeIDToOID(rw.ID) } } types.RemapUserDefinedTypeOIDs(typ, newOID, newArrayOID) // If the type is an array, then we need to rewrite the element type as well. if typ.Family() == types.ArrayFamily { - rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites) + if err := rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites); err != nil { + return err + } } + + return nil } // rewriteTypeDescs rewrites all ID's in the input slice of TypeDescriptors @@ -1075,7 +1100,9 @@ func rewriteTypeDescs(types []*typedesc.Mutable, descriptorRewrites DescRewriteM } case descpb.TypeDescriptor_ALIAS: // We need to rewrite any ID's present in the aliased types.T. - rewriteIDsInTypesT(typ.Alias, descriptorRewrites) + if err := rewriteIDsInTypesT(typ.Alias, descriptorRewrites); err != nil { + return err + } default: return errors.AssertionFailedf("unknown type kind %s", t.String()) } @@ -1285,7 +1312,9 @@ func RewriteTableDescs( // rewriteCol is a closure that performs the ID rewrite logic on a column. rewriteCol := func(col *descpb.ColumnDescriptor) error { // Rewrite the types.T's IDs present in the column.
- rewriteIDsInTypesT(col.Type, descriptorRewrites) + if err := rewriteIDsInTypesT(col.Type, descriptorRewrites); err != nil { + return err + } var newUsedSeqRefs []descpb.ID for _, seqID := range col.UsesSequenceIds { if rewrite, ok := descriptorRewrites[seqID]; ok { diff --git a/pkg/ccl/changefeedccl/changefeed_processors.go b/pkg/ccl/changefeedccl/changefeed_processors.go index 7a6b4e7fc6cf..24fd3e4b0e9e 100644 --- a/pkg/ccl/changefeedccl/changefeed_processors.go +++ b/pkg/ccl/changefeedccl/changefeed_processors.go @@ -79,13 +79,20 @@ type changeAggregator struct { eventProducer kvEventProducer // eventConsumer consumes the event. eventConsumer kvEventConsumer + + // Flush-related fields: clock is used to obtain the current hlc time; // lastFlush and flushFrequency keep track of the flush frequency. - lastFlush time.Time + clock *hlc.Clock + + lastFlush hlc.Timestamp flushFrequency time.Duration - // spansToFlush keeps track of resolved spans that have not been flushed yet. - spansToFlush []*jobspb.ResolvedSpan - // spanFrontier keeps track of resolved timestamps for spans. - spanFrontier *span.Frontier + + // frontier keeps track of resolved timestamps for spans along with schema change + // boundary information. + frontier *schemaChangeFrontier + + // numFrontierUpdates is the number of frontier updates since the last flush. + numFrontierUpdates int metrics *Metrics knobs TestingKnobs @@ -132,11 +139,13 @@ func newChangeAggregatorProcessor( ) (execinfra.Processor, error) { ctx := flowCtx.EvalCtx.Ctx() memMonitor := execinfra.NewMonitor(ctx, flowCtx.EvalCtx.Mon, "changeagg-mem") + clock := flowCtx.Cfg.DB.Clock() ca := &changeAggregator{ flowCtx: flowCtx, spec: spec, memAcc: memMonitor.MakeBoundAccount(), - lastFlush: timeutil.Now(), + clock: clock, + lastFlush: clock.Now(), } if err := ca.Init( ca, @@ -210,7 +219,10 @@ func (ca *changeAggregator) Start(ctx context.Context) { ca.cancel() return } - timestampOracle := &changeAggregatorLowerBoundOracle{sf: ca.spanFrontier, initialInclusiveLowerBound: ca.spec.Feed.StatementTime} + timestampOracle := &changeAggregatorLowerBoundOracle{ + sf: ca.frontier.SpanFrontier(), + initialInclusiveLowerBound: ca.spec.Feed.StatementTime, + } if cfKnobs, ok := ca.flowCtx.TestingKnobs().Changefeed.(*TestingKnobs); ok { ca.knobs = *cfKnobs @@ -263,7 +275,8 @@ func (ca *changeAggregator) Start(ctx context.Context) { if ca.spec.Feed.Opts[changefeedbase.OptFormat] == string(changefeedbase.OptFormatNative) { ca.eventConsumer = newNativeKVConsumer(ca.sink) } else { - ca.eventConsumer = newKVEventToRowConsumer(ctx, cfg, ca.spanFrontier, kvfeedCfg.InitialHighWater, + ca.eventConsumer = newKVEventToRowConsumer( + ctx, cfg, ca.frontier.SpanFrontier(), kvfeedCfg.InitialHighWater, ca.sink, ca.encoder, ca.spec.Feed, ca.knobs) } @@ -448,12 +461,12 @@ func (ca *changeAggregator) setupSpansAndFrontier() (spans []roachpb.Span, err e spans = append(spans, watch.Span) } - ca.spanFrontier, err = span.MakeFrontier(spans...) + ca.frontier, err = makeSchemaChangeFrontier(ca.clock, spans...)
if err != nil { return nil, err } for _, watch := range ca.spec.Watches { - if _, err := ca.spanFrontier.Forward(watch.Span, watch.InitialResolved); err != nil { + if _, err := ca.frontier.Forward(watch.Span, watch.InitialResolved); err != nil { return nil, err } } @@ -546,10 +559,10 @@ func (ca *changeAggregator) tick() error { } case kvfeed.ResolvedEvent: resolved := event.Resolved() - if _, err := ca.spanFrontier.Forward(resolved.Span, resolved.Timestamp); err != nil { + if _, err := ca.frontier.ForwardResolvedSpan(*resolved); err != nil { return err } - ca.spansToFlush = append(ca.spansToFlush, resolved) + ca.numFrontierUpdates++ forceFlush = resolved.BoundaryType != jobspb.ResolvedSpan_NONE } @@ -558,7 +571,7 @@ func (ca *changeAggregator) tick() error { // maybeFlush flushes sink and emits resolved timestamp if needed. func (ca *changeAggregator) maybeFlush(force bool) error { - if len(ca.spansToFlush) == 0 || (timeutil.Since(ca.lastFlush) < ca.flushFrequency && !force) { + if ca.numFrontierUpdates == 0 || (timeutil.Since(ca.lastFlush.GoTime()) < ca.flushFrequency && !force) { return nil } @@ -569,28 +582,48 @@ func (ca *changeAggregator) maybeFlush(force bool) error { if err := ca.sink.Flush(ca.Ctx); err != nil { return err } - ca.lastFlush = timeutil.Now() - // Iterate the spans in reverse so that if there are a very large number of - // spans which we're propagating upwards get processed in newest to oldest - // order. This will ultimately improve the efficiency of the checkpointing - // code which wants to checkpoint whenever the frontier changes. - for i := len(ca.spansToFlush) - 1; i >= 0; i-- { - resolvedBytes, err := protoutil.Marshal(ca.spansToFlush[i]) + // Iterate over spans whose timestamps were updated after the last flush and + // emit resolved span records. + var err error + ca.frontier.UpdatedEntries(ca.lastFlush, func(s roachpb.Span, ts hlc.Timestamp) span.OpResult { + err = ca.emitResolved(s, ts, ca.frontier.boundaryTypeAt(ts)) if err != nil { - return err + return span.StopMatch } - // Enqueue a row to be returned that indicates some span-level resolved - // timestamp has advanced. If any rows were queued in `sink`, they must - // be emitted first. - ca.resolvedSpanBuf.Push(rowenc.EncDatumRow{ - rowenc.EncDatum{Datum: tree.NewDBytes(tree.DBytes(resolvedBytes))}, - rowenc.EncDatum{Datum: tree.DNull}, // topic - rowenc.EncDatum{Datum: tree.DNull}, // key - rowenc.EncDatum{Datum: tree.DNull}, // value - }) - } - ca.spansToFlush = ca.spansToFlush[:0] + return span.ContinueMatch + }) + + if err != nil { + return err + } + + ca.lastFlush = ca.clock.Now() + ca.numFrontierUpdates = 0 + return nil +} + +func (ca *changeAggregator) emitResolved( + s roachpb.Span, ts hlc.Timestamp, boundary jobspb.ResolvedSpan_BoundaryType, +) error { + var resolvedSpan jobspb.ResolvedSpan + resolvedSpan.Span = s + resolvedSpan.Timestamp = ts + resolvedSpan.BoundaryType = boundary + + resolvedBytes, err := protoutil.Marshal(&resolvedSpan) + if err != nil { + return err + } + // Enqueue a row to be returned that indicates some span-level resolved + // timestamp has advanced. If any rows were queued in `sink`, they must + // be emitted first.
+ ca.resolvedSpanBuf.Push(rowenc.EncDatumRow{ + rowenc.EncDatum{Datum: tree.NewDBytes(tree.DBytes(resolvedBytes))}, + rowenc.EncDatum{Datum: tree.DNull}, // topic + rowenc.EncDatum{Datum: tree.DNull}, // key + rowenc.EncDatum{Datum: tree.DNull}, // value + }) return nil } @@ -886,9 +919,9 @@ type changeFrontier struct { // input returns rows from one or more changeAggregator processors input execinfra.RowSource - // sf contains the current resolved timestamp high-water for the tracked + // frontier contains the current resolved timestamp high-water for the tracked // span set. - sf *span.Frontier + frontier *schemaChangeFrontier // encoder is the Encoder to use for resolved timestamp serialization. encoder Encoder // sink is the Sink to write resolved timestamps to. Rows are never written @@ -903,26 +936,6 @@ type changeFrontier struct { // slowLogEveryN rate-limits the logging of slow spans slowLogEveryN log.EveryN - // schemaChangeBoundary represents an hlc timestamp at which a schema change - // event occurred to a target watched by this frontier. If the changefeed is - // configured to stop on schema change then the changeFrontier will wait for - // the span frontier to reach the schemaChangeBoundary, will drain, and then - // will exit. If the changefeed is configured to backfill on schema changes, - // the changeFrontier will protect the scan timestamp in order to ensure that - // the scan complete. The protected timestamp will be released when a new scan - // schemaChangeBoundary is created or the changefeed reaches a timestamp that - // is near the present. - // - // schemaChangeBoundary values are communicated to the changeFrontier via - // Resolved messages send from the changeAggregators. The policy regarding - // which schema change events lead to a schemaChangeBoundary is controlled - // by the KV feed based on OptSchemaChangeEvents and OptSchemaChangePolicy. - schemaChangeBoundary hlc.Timestamp - - // boundaryType indicates the type of the schemaChangeBoundary and thus the - // action which should be taken when the frontier reaches that boundary. - boundaryType jobspb.ResolvedSpan_BoundaryType - // js, if non-nil, is called to checkpoint the changefeed's // progress in the corresponding system job entry. js *jobState @@ -967,7 +980,7 @@ func newChangeFrontierProcessor( ) (execinfra.Processor, error) { ctx := flowCtx.EvalCtx.Ctx() memMonitor := execinfra.NewMonitor(ctx, flowCtx.EvalCtx.Mon, "changefntr-mem") - sf, err := span.MakeFrontier(spec.TrackedSpans...) + sf, err := makeSchemaChangeFrontier(flowCtx.Cfg.DB.Clock(), spec.TrackedSpans...) if err != nil { return nil, err } @@ -976,7 +989,7 @@ func newChangeFrontierProcessor( spec: spec, memAcc: memMonitor.MakeBoundAccount(), input: input, - sf: sf, + frontier: sf, slowLogEveryN: log.Every(slowSpanMaxFrequency), } if err := cf.Init( @@ -1125,12 +1138,6 @@ func (cf *changeFrontier) closeMetrics() { cf.metrics.mu.Unlock() } -// schemaChangeBoundaryReached returns true if the spanFrontier is at the -// current schemaChangeBoundary. -func (cf *changeFrontier) schemaChangeBoundaryReached() (r bool) { - return !cf.schemaChangeBoundary.IsEmpty() && cf.schemaChangeBoundary.Equal(cf.sf.Frontier()) -} - // shouldProtectBoundaries checks the job's spec to determine whether it should // install protected timestamps when encountering scan boundaries. 
func (cf *changeFrontier) shouldProtectBoundaries() bool { @@ -1147,15 +1154,15 @@ func (cf *changeFrontier) Next() (rowenc.EncDatumRow, *execinfrapb.ProducerMetad return cf.ProcessRowHelper(cf.resolvedBuf.Pop()), nil } - if cf.schemaChangeBoundaryReached() && - (cf.boundaryType == jobspb.ResolvedSpan_EXIT || - cf.boundaryType == jobspb.ResolvedSpan_RESTART) { + if cf.frontier.schemaChangeBoundaryReached() && + (cf.frontier.boundaryType == jobspb.ResolvedSpan_EXIT || + cf.frontier.boundaryType == jobspb.ResolvedSpan_RESTART) { err := pgerror.Newf(pgcode.SchemaChangeOccurred, - "schema change occurred at %v", cf.schemaChangeBoundary.Next().AsOfSystemTime()) + "schema change occurred at %v", cf.frontier.boundaryTime.Next().AsOfSystemTime()) // Detect whether this boundary should be used to kill or restart the // changefeed. - if cf.boundaryType == jobspb.ResolvedSpan_RESTART { + if cf.frontier.boundaryType == jobspb.ResolvedSpan_RESTART { // The code to restart the changefeed is only supported once 21.1 is // activated. // @@ -1232,27 +1239,7 @@ func (cf *changeFrontier) noteResolvedSpan(d rowenc.EncDatum) error { return nil } - // We want to ensure that we mark the schemaChangeBoundary and then we want - // to detect when the frontier reaches to or past the schemaChangeBoundary. - // The behavior when the boundary is reached is controlled by the - // boundaryType. - switch resolved.BoundaryType { - case jobspb.ResolvedSpan_NONE: - if !cf.schemaChangeBoundary.IsEmpty() && cf.schemaChangeBoundary.Less(resolved.Timestamp) { - cf.schemaChangeBoundary = hlc.Timestamp{} - cf.boundaryType = jobspb.ResolvedSpan_NONE - } - case jobspb.ResolvedSpan_BACKFILL, jobspb.ResolvedSpan_EXIT, jobspb.ResolvedSpan_RESTART: - if !cf.schemaChangeBoundary.IsEmpty() && resolved.Timestamp.Less(cf.schemaChangeBoundary) { - return errors.AssertionFailedf("received boundary timestamp %v < %v "+ - "of type %v before reaching existing boundary of type %v", - resolved.Timestamp, cf.schemaChangeBoundary, resolved.BoundaryType, cf.boundaryType) - } - cf.schemaChangeBoundary = resolved.Timestamp - cf.boundaryType = resolved.BoundaryType - } - - frontierChanged, err := cf.sf.Forward(resolved.Span, resolved.Timestamp) + frontierChanged, err := cf.frontier.ForwardResolvedSpan(resolved) if err != nil { return err } @@ -1266,7 +1253,7 @@ func (cf *changeFrontier) noteResolvedSpan(d rowenc.EncDatum) error { } func (cf *changeFrontier) handleFrontierChanged(isBehind bool) error { - newResolved := cf.sf.Frontier() + newResolved := cf.frontier.Frontier() cf.metrics.mu.Lock() if cf.metricsID != -1 { cf.metrics.mu.resolved[cf.metricsID] = newResolved @@ -1379,7 +1366,7 @@ func (cf *changeFrontier) maybeReleaseProtectedTimestamp( if progress.ProtectedTimestampRecord == uuid.Nil { return nil } - if !cf.schemaChangeBoundaryReached() && isBehind { + if !cf.frontier.schemaChangeBoundaryReached() && isBehind { log.VEventf(ctx, 2, "not releasing protected timestamp because changefeed is behind") return nil } @@ -1402,7 +1389,7 @@ func (cf *changeFrontier) maybeProtectTimestamp( txn *kv.Txn, resolved hlc.Timestamp, ) error { - if cf.isSinkless() || !cf.schemaChangeBoundaryReached() || !cf.shouldProtectBoundaries() { + if cf.isSinkless() || !cf.frontier.schemaChangeBoundaryReached() || !cf.shouldProtectBoundaries() { return nil } @@ -1416,7 +1403,7 @@ func (cf *changeFrontier) maybeEmitResolved(newResolved hlc.Timestamp) error { return nil } sinceEmitted := newResolved.GoTime().Sub(cf.lastEmitResolved) - shouldEmit := sinceEmitted >= 
cf.freqEmitResolved || cf.schemaChangeBoundaryReached() + shouldEmit := sinceEmitted >= cf.freqEmitResolved || cf.frontier.schemaChangeBoundaryReached() if !shouldEmit { return nil } @@ -1433,7 +1420,7 @@ func (cf *changeFrontier) maybeEmitResolved(newResolved hlc.Timestamp) error { // returned boolean will be true if the resolved timestamp lags far behind the // present as defined by the current configuration. func (cf *changeFrontier) maybeLogBehindSpan(frontierChanged bool) (isBehind bool) { - frontier := cf.sf.Frontier() + frontier := cf.frontier.Frontier() now := timeutil.Now() resolvedBehind := now.Sub(frontier.GoTime()) if resolvedBehind <= cf.slownessThreshold() { @@ -1450,7 +1437,7 @@ func (cf *changeFrontier) maybeLogBehindSpan(frontierChanged bool) (isBehind boo } if cf.slowLogEveryN.ShouldProcess(now) { - s := cf.sf.PeekFrontierSpan() + s := cf.frontier.PeekFrontierSpan() log.Infof(cf.Ctx, "%s span %s is behind by %s", description, s, resolvedBehind) } return true @@ -1486,3 +1473,98 @@ func (cf *changeFrontier) ConsumerClosed() { func (cf *changeFrontier) isSinkless() bool { return cf.spec.JobID == 0 } + +// spanFrontier wraps span.Frontier to make embedding it in schemaChangeFrontier convenient. +type spanFrontier struct { + *span.Frontier +} + +func (s *spanFrontier) frontierTimestamp() hlc.Timestamp { + return s.Frontier.Frontier() +} + +// schemaChangeFrontier wraps a span frontier, tracking resolved timestamps for spans +// along with schema change boundary information. +type schemaChangeFrontier struct { + *spanFrontier + + // boundaryTime represents an hlc timestamp at which a schema change + // event occurred to a target watched by this frontier. If the changefeed is + // configured to stop on schema change then the changeFrontier will wait for + // the span frontier to reach the schemaChangeBoundary, will drain, and then + // will exit. If the changefeed is configured to backfill on schema changes, + // the changeFrontier will protect the scan timestamp in order to ensure that + // the scan completes. The protected timestamp will be released when a new scan + // schemaChangeBoundary is created or the changefeed reaches a timestamp that + // is near the present. + // + // schemaChangeBoundary values are communicated to the changeFrontier via + // Resolved messages sent from the changeAggregators. The policy regarding + // which schema change events lead to a schemaChangeBoundary is controlled + // by the KV feed based on OptSchemaChangeEvents and OptSchemaChangePolicy. + boundaryTime hlc.Timestamp + + // boundaryType indicates the type of the schemaChangeBoundary and thus the + // action which should be taken when the frontier reaches that boundary. + boundaryType jobspb.ResolvedSpan_BoundaryType +} + +func makeSchemaChangeFrontier(c *hlc.Clock, spans ...roachpb.Span) (*schemaChangeFrontier, error) { + sf, err := span.MakeFrontier(spans...) + if err != nil { + return nil, err + } + f := &schemaChangeFrontier{spanFrontier: &spanFrontier{sf}} + f.spanFrontier.TrackUpdateTimestamp(func() hlc.Timestamp { return c.Now() }) + return f, nil +} + +// ForwardResolvedSpan advances the timestamp for a resolved span. +// It also takes care of updating schema change boundary information. +func (f *schemaChangeFrontier) ForwardResolvedSpan(r jobspb.ResolvedSpan) (bool, error) { + // We want to ensure that we mark the schemaChangeBoundary and then we want + // to detect when the frontier reaches or passes the schemaChangeBoundary. + // The behavior when the boundary is reached is controlled by the + // boundary type. + // NB: the boundaryType and boundaryTime update machinery is tricky. In particular, + // we do not reset to ResolvedSpan_NONE while a newer non-NONE boundary is still pending. + switch r.BoundaryType { + case jobspb.ResolvedSpan_NONE: + if !f.boundaryTime.IsEmpty() && f.boundaryTime.Less(r.Timestamp) { + f.boundaryTime = hlc.Timestamp{} + f.boundaryType = jobspb.ResolvedSpan_NONE + } + case jobspb.ResolvedSpan_BACKFILL, jobspb.ResolvedSpan_EXIT, jobspb.ResolvedSpan_RESTART: + if !f.boundaryTime.IsEmpty() && r.Timestamp.Less(f.boundaryTime) { + return false, errors.AssertionFailedf("received boundary timestamp %v < %v "+ + "of type %v before reaching existing boundary of type %v", + r.Timestamp, f.boundaryTime, r.BoundaryType, f.boundaryType) + } + f.boundaryTime = r.Timestamp + f.boundaryType = r.BoundaryType + } + return f.Forward(r.Span, r.Timestamp) +} + +// Frontier returns the minimum timestamp being tracked. +func (f *schemaChangeFrontier) Frontier() hlc.Timestamp { + return f.frontierTimestamp() +} + +// SpanFrontier returns the underlying span.Frontier. +func (f *schemaChangeFrontier) SpanFrontier() *span.Frontier { + return f.spanFrontier.Frontier +} + +// schemaChangeBoundaryReached returns true if the schema change boundary has been reached. +func (f *schemaChangeFrontier) schemaChangeBoundaryReached() (r bool) { + return !f.boundaryTime.IsEmpty() && f.boundaryTime.Equal(f.Frontier()) +} + +// boundaryTypeAt returns the boundary type applicable at the specified timestamp. +func (f *schemaChangeFrontier) boundaryTypeAt(ts hlc.Timestamp) jobspb.ResolvedSpan_BoundaryType { + if f.boundaryTime.IsEmpty() || ts.Less(f.boundaryTime) { + return jobspb.ResolvedSpan_NONE + } + return f.boundaryType } diff --git a/pkg/ccl/changefeedccl/schemafeed/schema_feed.go b/pkg/ccl/changefeedccl/schemafeed/schema_feed.go index 628229e39bf0..a0366f99d746 100644 --- a/pkg/ccl/changefeedccl/schemafeed/schema_feed.go +++ b/pkg/ccl/changefeedccl/schemafeed/schema_feed.go @@ -178,16 +178,27 @@ func (t *typeDependencyTracker) removeDependency(typeID, tableID descpb.ID) { } } -func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) { +func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) error { for _, col := range tbl.UserDefinedTypeColumns() { - t.removeDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID()) + id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()) + if err != nil { + return err + } + t.removeDependency(id, tbl.GetID()) } + + return nil } -func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) { +func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) error { for _, col := range tbl.UserDefinedTypeColumns() { - t.addDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID()) + id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()) + if err != nil { + return err + } + t.addDependency(id, tbl.GetID()) } + return nil } func (t *typeDependencyTracker) containsType(id descpb.ID) bool { @@ -289,7 +300,10 @@ func (tf *SchemaFeed) primeInitialTableDescs(ctx context.Context) error { // Register all types used by the initial set of tables.
for _, desc := range initialDescs { tbl := desc.(catalog.TableDescriptor) - tf.mu.typeDeps.ingestTable(tbl) + if err := tf.mu.typeDeps.ingestTable(tbl); err != nil { + tf.mu.Unlock() + return err + } } tf.mu.Unlock() @@ -533,7 +547,9 @@ func (tf *SchemaFeed) validateDescriptor( } // Purge the old version of the table from the type mapping. - tf.mu.typeDeps.purgeTable(lastVersion) + if err := tf.mu.typeDeps.purgeTable(lastVersion); err != nil { + return err + } e := TableEvent{ Before: lastVersion, @@ -559,7 +575,9 @@ func (tf *SchemaFeed) validateDescriptor( } } // Add the types used by the table into the dependency tracker. - tf.mu.typeDeps.ingestTable(desc) + if err := tf.mu.typeDeps.ingestTable(desc); err != nil { + return err + } tf.mu.previousTableVersion[desc.GetID()] = desc return nil default: diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index a595af981f85..85aa241635ae 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -302,7 +302,6 @@ go_library( "//pkg/sql/inverted", "//pkg/sql/lex", "//pkg/sql/mutations", - "//pkg/sql/oidext", "//pkg/sql/opt", "//pkg/sql/opt/cat", "//pkg/sql/opt/constraint", diff --git a/pkg/sql/catalog/dbdesc/database_desc.go b/pkg/sql/catalog/dbdesc/database_desc.go index b083b89817d3..6d26ac76d701 100644 --- a/pkg/sql/catalog/dbdesc/database_desc.go +++ b/pkg/sql/catalog/dbdesc/database_desc.go @@ -225,15 +225,19 @@ func (desc *immutable) validateMultiRegion(vea catalog.ValidationErrorAccumulato // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. -func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet { +func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { ids := catalog.MakeDescriptorIDSet(desc.GetID()) - if id, err := desc.MultiRegionEnumID(); err == nil { + if desc.IsMultiRegion() { + id, err := desc.MultiRegionEnumID() + if err != nil { + return catalog.DescriptorIDSet{}, err + } ids.Add(id) } for _, schema := range desc.Schemas { ids.Add(schema.ID) } - return ids + return ids, nil } // ValidateCrossReferences implements the catalog.Descriptor interface. diff --git a/pkg/sql/catalog/descriptor.go b/pkg/sql/catalog/descriptor.go index 3d22ff81bf91..1b07c5f632d2 100644 --- a/pkg/sql/catalog/descriptor.go +++ b/pkg/sql/catalog/descriptor.go @@ -140,7 +140,7 @@ type Descriptor interface { // GetReferencedDescIDs returns the IDs of all descriptors directly referenced // by this descriptor, including itself. - GetReferencedDescIDs() DescriptorIDSet + GetReferencedDescIDs() (DescriptorIDSet, error) // ValidateSelf checks the internal consistency of the descriptor. ValidateSelf(vea ValidationErrorAccumulator) @@ -358,7 +358,7 @@ type TypeDescriptor interface { HydrateTypeInfoWithName(ctx context.Context, typ *types.T, name *tree.TypeName, res TypeDescriptorResolver) error MakeTypesT(ctx context.Context, name *tree.TypeName, res TypeDescriptorResolver) (*types.T, error) HasPendingSchemaChanges() bool - GetIDClosure() map[descpb.ID]struct{} + GetIDClosure() (map[descpb.ID]struct{}, error) IsCompatibleWith(other TypeDescriptor) error PrimaryRegionName() (descpb.RegionName, error) diff --git a/pkg/sql/catalog/descs/collection.go b/pkg/sql/catalog/descs/collection.go index b1501e8f7b70..371044db6554 100644 --- a/pkg/sql/catalog/descs/collection.go +++ b/pkg/sql/catalog/descs/collection.go @@ -2180,7 +2180,11 @@ func (dt DistSQLTypeResolver) ResolveType( // ResolveTypeByOID implements the tree.TypeReferenceResolver interface. 
func (dt DistSQLTypeResolver) ResolveTypeByOID(ctx context.Context, oid oid.Oid) (*types.T, error) { - name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(oid)) + id, err := typedesc.UserDefinedTypeOIDToID(oid) + if err != nil { + return nil, err + } + name, desc, err := dt.GetTypeDescriptor(ctx, id) if err != nil { return nil, err } @@ -2213,7 +2217,11 @@ func (dt DistSQLTypeResolver) GetTypeDescriptor( func (dt DistSQLTypeResolver) HydrateTypeSlice(ctx context.Context, typs []*types.T) error { for _, t := range typs { if t.UserDefined() { - name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.GetTypeDescID(t)) + id, err := typedesc.GetUserDefinedTypeDescID(t) + if err != nil { + return err + } + name, desc, err := dt.GetTypeDescriptor(ctx, id) if err != nil { return err } diff --git a/pkg/sql/catalog/schemadesc/schema_desc.go b/pkg/sql/catalog/schemadesc/schema_desc.go index 27701353c397..abadad7655f3 100644 --- a/pkg/sql/catalog/schemadesc/schema_desc.go +++ b/pkg/sql/catalog/schemadesc/schema_desc.go @@ -146,8 +146,8 @@ func (desc *immutable) ValidateSelf(vea catalog.ValidationErrorAccumulator) { // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. -func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet { - return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()) +func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { + return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()), nil } // ValidateCrossReferences implements the catalog.Descriptor interface. diff --git a/pkg/sql/catalog/tabledesc/structured.go b/pkg/sql/catalog/tabledesc/structured.go index aa218aa0cfce..123c84693cf6 100644 --- a/pkg/sql/catalog/tabledesc/structured.go +++ b/pkg/sql/catalog/tabledesc/structured.go @@ -500,27 +500,44 @@ func (desc *wrapper) getAllReferencedTypesInTableColumns( // collect the closure of ID's referenced. ids := make(map[descpb.ID]struct{}) for id := range visitor.OIDs { - typDesc, err := getType(typedesc.UserDefinedTypeOIDToID(id)) + uid, err := typedesc.UserDefinedTypeOIDToID(id) if err != nil { return nil, err } - for child := range typDesc.GetIDClosure() { + typDesc, err := getType(uid) + if err != nil { + return nil, err + } + children, err := typDesc.GetIDClosure() + if err != nil { + return nil, err + } + for child := range children { ids[child] = struct{}{} } } // Now add all of the column types in the table. - addIDsInColumn := func(c *descpb.ColumnDescriptor) { - for id := range typedesc.GetTypeDescriptorClosure(c.Type) { + addIDsInColumn := func(c *descpb.ColumnDescriptor) error { + children, err := typedesc.GetTypeDescriptorClosure(c.Type) + if err != nil { + return err + } + for id := range children { ids[id] = struct{}{} } + return nil } for i := range desc.Columns { - addIDsInColumn(&desc.Columns[i]) + if err := addIDsInColumn(&desc.Columns[i]); err != nil { + return nil, err + } } for _, mut := range desc.Mutations { if c := mut.GetColumn(); c != nil { - addIDsInColumn(c) + if err := addIDsInColumn(c); err != nil { + return nil, err + } } } diff --git a/pkg/sql/catalog/tabledesc/validate.go b/pkg/sql/catalog/tabledesc/validate.go index 3ba8c901b617..79400ba9c084 100644 --- a/pkg/sql/catalog/tabledesc/validate.go +++ b/pkg/sql/catalog/tabledesc/validate.go @@ -47,7 +47,7 @@ func (desc *wrapper) ValidateTxnCommit( // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. 
-func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { +func (desc *wrapper) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { ids := catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()) if desc.GetParentSchemaID() != keys.PublicSchemaID { ids.Add(desc.GetParentSchemaID()) } @@ -69,7 +69,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { } // Collect user defined type Oids and sequence references in columns. for _, col := range desc.DeletableColumns() { - for id := range typedesc.GetTypeDescriptorClosure(col.GetType()) { + children, err := typedesc.GetTypeDescriptorClosure(col.GetType()) + if err != nil { + return catalog.DescriptorIDSet{}, err + } + for id := range children { ids.Add(id) } for i := 0; i < col.NumUsesSequences(); i++ { @@ -89,7 +93,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { }) // Add collected Oids to return set. for oid := range visitor.OIDs { - ids.Add(typedesc.UserDefinedTypeOIDToID(oid)) + id, err := typedesc.UserDefinedTypeOIDToID(oid) + if err != nil { + return catalog.DescriptorIDSet{}, err + } + ids.Add(id) } // Add view dependencies. for _, id := range desc.GetDependsOn() { @@ -102,7 +110,7 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { ids.Add(ref.ID) } // Add sequence dependencies - return ids + return ids, nil } // ValidateCrossReferences validates that each reference to another table is diff --git a/pkg/sql/catalog/typedesc/BUILD.bazel b/pkg/sql/catalog/typedesc/BUILD.bazel index 10b9127b8ef1..a03794ab256c 100644 --- a/pkg/sql/catalog/typedesc/BUILD.bazel +++ b/pkg/sql/catalog/typedesc/BUILD.bazel @@ -46,6 +46,7 @@ go_test( "//pkg/sql/catalog/dbdesc", "//pkg/sql/catalog/descpb", "//pkg/sql/catalog/schemadesc", + "//pkg/sql/oidext", "//pkg/sql/privilege", "//pkg/sql/types", "//pkg/testutils", @@ -53,6 +54,7 @@ go_test( "//pkg/util/leaktest", "//pkg/util/randutil", "@com_github_cockroachdb_redact//:redact", + "@com_github_lib_pq//oid", "@com_github_stretchr_testify//require", "@in_gopkg_yaml_v2//:yaml_v2", ], diff --git a/pkg/sql/catalog/typedesc/type_desc.go b/pkg/sql/catalog/typedesc/type_desc.go index 5bb2ee0760fc..17c370ae42f9 100644 --- a/pkg/sql/catalog/typedesc/type_desc.go +++ b/pkg/sql/catalog/typedesc/type_desc.go @@ -114,19 +114,25 @@ func TypeIDToOID(id descpb.ID) oid.Oid { } // UserDefinedTypeOIDToID converts a user defined type OID into a -// descriptor ID. +// descriptor ID. The OID of a user-defined type must be greater than +// CockroachPredefinedOIDMax. The function returns an error if the +// given OID is less than or equal to CockroachPredefinedOIDMax. -func UserDefinedTypeOIDToID(oid oid.Oid) descpb.ID { - return descpb.ID(oid) - oidext.CockroachPredefinedOIDMax +func UserDefinedTypeOIDToID(oid oid.Oid) (descpb.ID, error) { + if descpb.ID(oid) <= oidext.CockroachPredefinedOIDMax { + return 0, errors.Newf("user-defined OID %d should be greater "+ + "than predefined max %d", oid, oidext.CockroachPredefinedOIDMax) + } + return descpb.ID(oid) - oidext.CockroachPredefinedOIDMax, nil } -// GetTypeDescID gets the type descriptor ID from a user defined type. -func GetTypeDescID(t *types.T) descpb.ID { +// GetUserDefinedTypeDescID gets the type descriptor ID from a user defined type.
+func GetUserDefinedTypeDescID(t *types.T) (descpb.ID, error) { return UserDefinedTypeOIDToID(t.Oid()) } -// GetArrayTypeDescID gets the ID of the array type descriptor from a user +// GetUserDefinedArrayTypeDescID gets the ID of the array type descriptor from a user // defined type. -func GetArrayTypeDescID(t *types.T) descpb.ID { +func GetUserDefinedArrayTypeDescID(t *types.T) (descpb.ID, error) { return UserDefinedTypeOIDToID(t.UserDefinedArrayOID()) } @@ -554,16 +560,20 @@ func (desc *immutable) validateEnumMembers(vea catalog.ValidationErrorAccumulato // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. -func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet { +func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { ids := catalog.MakeDescriptorIDSet(desc.GetReferencingDescriptorIDs()...) ids.Add(desc.GetParentID()) if desc.GetParentSchemaID() != keys.PublicSchemaID { ids.Add(desc.GetParentSchemaID()) } - for id := range desc.GetIDClosure() { + children, err := desc.GetIDClosure() + if err != nil { + return catalog.DescriptorIDSet{}, err + } + for id := range children { ids.Add(id) } - return ids + return ids, nil } // ValidateCrossReferences performs cross reference checks on the type descriptor. @@ -599,7 +609,10 @@ func (desc *immutable) ValidateCrossReferences( } case descpb.TypeDescriptor_ALIAS: if desc.GetAlias().UserDefined() { - aliasedID := UserDefinedTypeOIDToID(desc.GetAlias().Oid()) + aliasedID, err := UserDefinedTypeOIDToID(desc.GetAlias().Oid()) + if err != nil { + vea.Report(err) + } if _, err := vdg.GetTypeDescriptor(aliasedID); err != nil { vea.Report(errors.Wrapf(err, "aliased type %d does not exist", aliasedID)) } @@ -724,7 +737,11 @@ func HydrateTypesInTableDescriptor( hydrateCol := func(col *descpb.ColumnDescriptor) error { if col.Type.UserDefined() { // Look up its type descriptor. - name, typDesc, err := res.GetTypeDescriptor(ctx, GetTypeDescID(col.Type)) + td, err := GetUserDefinedTypeDescID(col.Type) + if err != nil { + return err + } + name, typDesc, err := res.GetTypeDescriptor(ctx, td) if err != nil { return err } @@ -787,7 +804,11 @@ func (desc *immutable) HydrateTypeInfoWithName( case types.ArrayFamily: // Hydrate the element type. elemType := typ.ArrayContents() - elemTypName, elemTypDesc, err := res.GetTypeDescriptor(ctx, GetTypeDescID(elemType)) + id, err := GetUserDefinedTypeDescID(elemType) + if err != nil { + return err + } + elemTypName, elemTypDesc, err := res.GetTypeDescriptor(ctx, id) if err != nil { return err } @@ -901,14 +922,17 @@ func (desc *immutable) HasPendingSchemaChanges() bool { // GetIDClosure returns all type descriptor IDs that are referenced by this // type descriptor. -func (desc *immutable) GetIDClosure() map[descpb.ID]struct{} { +func (desc *immutable) GetIDClosure() (map[descpb.ID]struct{}, error) { ret := make(map[descpb.ID]struct{}) // Collect the descriptor's own ID. ret[desc.ID] = struct{}{} if desc.Kind == descpb.TypeDescriptor_ALIAS { // If this descriptor is an alias for another type, then get collect the // closure for alias. - children := GetTypeDescriptorClosure(desc.Alias) + children, err := GetTypeDescriptorClosure(desc.Alias) + if err != nil { + return nil, err + } for id := range children { ret[id] = struct{}{} } @@ -916,28 +940,39 @@ func (desc *immutable) GetIDClosure() map[descpb.ID]struct{} { // Otherwise, take the array type ID. 
ret[desc.ArrayTypeID] = struct{}{} } - return ret + return ret, nil } // GetTypeDescriptorClosure returns all type descriptor IDs that are // referenced by this input types.T. -func GetTypeDescriptorClosure(typ *types.T) map[descpb.ID]struct{} { +func GetTypeDescriptorClosure(typ *types.T) (map[descpb.ID]struct{}, error) { if !typ.UserDefined() { - return map[descpb.ID]struct{}{} + return map[descpb.ID]struct{}{}, nil + } + id, err := GetUserDefinedTypeDescID(typ) + if err != nil { + return nil, err } // Collect the type's descriptor ID. ret := map[descpb.ID]struct{}{ - GetTypeDescID(typ): {}, + id: {}, } if typ.Family() == types.ArrayFamily { // If we have an array type, then collect all types in the contents. - children := GetTypeDescriptorClosure(typ.ArrayContents()) + children, err := GetTypeDescriptorClosure(typ.ArrayContents()) + if err != nil { + return nil, err + } for id := range children { ret[id] = struct{}{} } } else { // Otherwise, take the array type ID. - ret[GetArrayTypeDescID(typ)] = struct{}{} + id, err := GetUserDefinedArrayTypeDescID(typ) + if err != nil { + return nil, err + } + ret[id] = struct{}{} } - return ret + return ret, nil } diff --git a/pkg/sql/catalog/typedesc/type_desc_test.go b/pkg/sql/catalog/typedesc/type_desc_test.go index 91e32be2502c..0a1a495f16ba 100644 --- a/pkg/sql/catalog/typedesc/type_desc_test.go +++ b/pkg/sql/catalog/typedesc/type_desc_test.go @@ -13,6 +13,7 @@ package typedesc_test import ( "context" "fmt" + "math" "testing" "github.com/cockroachdb/cockroach/pkg/keys" @@ -22,10 +23,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemadesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" + "github.com/cockroachdb/cockroach/pkg/sql/oidext" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/lib/pq/oid" "github.com/stretchr/testify/require" ) @@ -784,3 +787,28 @@ func TestValidateTypeDesc(t *testing.T) { } } } + +func TestOIDToIDConversion(t *testing.T) { + tests := []struct { + oid oid.Oid + ok bool + name string + }{ + {oid.Oid(0), false, "default OID"}, + {oid.Oid(1), false, "Standard OID"}, + {oid.Oid(oidext.CockroachPredefinedOIDMax), false, "max standard OID"}, + {oid.Oid(oidext.CockroachPredefinedOIDMax + 1), true, "user-defined OID"}, + {oid.Oid(math.MaxUint32), true, "max user-defined OID"}, + } + + for _, test := range tests { + t.Run(fmt.Sprint(test.oid), func(t *testing.T) { + _, err := typedesc.UserDefinedTypeOIDToID(test.oid) + if test.ok { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} diff --git a/pkg/sql/catalog/validate.go b/pkg/sql/catalog/validate.go index 3ad024fbfc7c..bb2b6e846a2d 100644 --- a/pkg/sql/catalog/validate.go +++ b/pkg/sql/catalog/validate.go @@ -437,9 +437,14 @@ type collectorState struct { } // addDirectReferences adds all immediate neighbors of desc to the state. 
-func (cs *collectorState) addDirectReferences(desc Descriptor) { +func (cs *collectorState) addDirectReferences(desc Descriptor) error { cs.vdg.Descriptors[desc.GetID()] = desc - desc.GetReferencedDescIDs().ForEach(cs.referencedBy.Add) + idSet, err := desc.GetReferencedDescIDs() + if err != nil { + return err + } + idSet.ForEach(cs.referencedBy.Add) + return nil } // getMissingDescs fetches the descriptors which have corresponding IDs in the @@ -491,7 +496,9 @@ func collectDescriptorsForValidation( referencedBy: MakeDescriptorIDSet(), } for _, desc := range descriptors { - cs.addDirectReferences(desc) + if err := cs.addDirectReferences(desc); err != nil { + return nil, err + } } newDescs, err := cs.getMissingDescs(ctx, maybeBatchDescGetter) if err != nil { @@ -503,7 +510,9 @@ func collectDescriptorsForValidation( } switch newDesc.(type) { case DatabaseDescriptor, TypeDescriptor: - cs.addDirectReferences(newDesc) + if err := cs.addDirectReferences(newDesc); err != nil { + return nil, err + } } } _, err = cs.getMissingDescs(ctx, maybeBatchDescGetter) diff --git a/pkg/sql/database_region_change_finalizer.go b/pkg/sql/database_region_change_finalizer.go index 1b7560c0f326..e911f2eddb64 100644 --- a/pkg/sql/database_region_change_finalizer.go +++ b/pkg/sql/database_region_change_finalizer.go @@ -169,8 +169,14 @@ func (r *databaseRegionChangeFinalizer) repartitionRegionalByRowTables( // the table descriptor with the new type metadata. for i := range tableDesc.Columns { col := &tableDesc.Columns[i] - if col.Type.UserDefined() && typedesc.UserDefinedTypeOIDToID(col.Type.Oid()) == r.typeID { - col.Type.TypeMeta = types.UserDefinedTypeMetadata{} + if col.Type.UserDefined() { + tid, err := typedesc.UserDefinedTypeOIDToID(col.Type.Oid()) + if err != nil { + return err + } + if tid == r.typeID { + col.Type.TypeMeta = types.UserDefinedTypeMetadata{} + } } } if err := typedesc.HydrateTypesInTableDescriptor( diff --git a/pkg/sql/opt/optbuilder/builder.go b/pkg/sql/opt/optbuilder/builder.go index 441576614e43..5228724bd4bf 100644 --- a/pkg/sql/opt/optbuilder/builder.go +++ b/pkg/sql/opt/optbuilder/builder.go @@ -429,7 +429,11 @@ func (b *Builder) maybeTrackRegclassDependenciesForViews(texpr tree.TypedExpr) { func (b *Builder) maybeTrackUserDefinedTypeDepsForViews(texpr tree.TypedExpr) { if b.trackViewDeps { if texpr.ResolvedType().UserDefined() { - for id := range typedesc.GetTypeDescriptorClosure(texpr.ResolvedType()) { + children, err := typedesc.GetTypeDescriptorClosure(texpr.ResolvedType()) + if err != nil { + panic(err) + } + for id := range children { b.viewTypeDeps.Add(int(id)) } } diff --git a/pkg/sql/opt/testutils/testcat/test_catalog.go b/pkg/sql/opt/testutils/testcat/test_catalog.go index 466c92b0d691..a17eecb91659 100644 --- a/pkg/sql/opt/testutils/testcat/test_catalog.go +++ b/pkg/sql/opt/testutils/testcat/test_catalog.go @@ -794,7 +794,11 @@ func (tt *Table) CollectTypes(ord int) (descpb.IDs, error) { ids := make(descpb.IDs, 0, len(visitor.OIDs)) for collectedOid := range visitor.OIDs { - ids = append(ids, typedesc.UserDefinedTypeOIDToID(collectedOid)) + id, err := typedesc.UserDefinedTypeOIDToID(collectedOid) + if err != nil { + return nil, err + } + ids = append(ids, id) } return ids, nil } diff --git a/pkg/sql/opt_catalog.go b/pkg/sql/opt_catalog.go index beada93cbb70..293ad01cd19c 100644 --- a/pkg/sql/opt_catalog.go +++ b/pkg/sql/opt_catalog.go @@ -2274,7 +2274,11 @@ func collectTypes(col catalog.Column) (descpb.IDs, error) { ids := make(descpb.IDs, 0, len(visitor.OIDs)) for collectedOid 
:= range visitor.OIDs { - ids = append(ids, typedesc.UserDefinedTypeOIDToID(collectedOid)) + id, err := typedesc.UserDefinedTypeOIDToID(collectedOid) + if err != nil { + return nil, err + } + ids = append(ids, id) } return ids, nil } diff --git a/pkg/sql/pg_catalog.go b/pkg/sql/pg_catalog.go index c7a52038a0ac..0528acd06725 100644 --- a/pkg/sql/pg_catalog.go +++ b/pkg/sql/pg_catalog.go @@ -30,7 +30,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemaexpr" "github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" - "github.com/cockroachdb/cockroach/pkg/sql/oidext" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" @@ -2437,17 +2436,20 @@ https://www.postgresql.org/docs/9.5/catalog-pg-type.html`, return true, nil } - // This oid is not a user-defined type and we didn't find it in the - // map of predefined types, return false. Note that in common usage we - // only really expect the value 0 here (which cockroach uses internally - // in the typelem field amongst others). Users, however, may join on - // this index with any value. - if ooid <= oidext.CockroachPredefinedOIDMax { + // Check if it is a user defined type. + if !types.IsOIDUserDefinedType(ooid) { + // This oid is not a user-defined type and we didn't find it in the + // map of predefined types, return false. Note that in common usage we + // only really expect the value 0 here (which cockroach uses internally + // in the typelem field amongst others). Users, however, may join on + // this index with any value. return false, nil } - // Check if it is a user defined type. - id := typedesc.UserDefinedTypeOIDToID(ooid) + id, err := typedesc.UserDefinedTypeOIDToID(ooid) + if err != nil { + return false, err + } typDesc, err := p.Descriptors().GetImmutableTypeByID(ctx, p.txn, id, tree.ObjectLookupFlags{}) if err != nil { if errors.Is(err, catalog.ErrDescriptorNotFound) { diff --git a/pkg/sql/resolver.go b/pkg/sql/resolver.go index 0e55c2f06d68..0b901ba4a06c 100644 --- a/pkg/sql/resolver.go +++ b/pkg/sql/resolver.go @@ -276,7 +276,16 @@ func (p *planner) IsTypeVisible( if _, ok := types.OidToType[typeID]; ok { return true, true, nil } - typName, _, err := p.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(typeID)) + + if !types.IsOIDUserDefinedType(typeID) { + return false, false, nil //nolint:returnerrcheck + } + + id, err := typedesc.UserDefinedTypeOIDToID(typeID) + if err != nil { + return false, false, err + } + typName, _, err := p.GetTypeDescriptor(ctx, id) if err != nil { // If a "not found" error happened here, we return "not exists" rather than // the error. @@ -361,7 +370,11 @@ func (p *planner) ResolveType( // ResolveTypeByOID implements the tree.TypeResolver interface. 
func (p *planner) ResolveTypeByOID(ctx context.Context, oid oid.Oid) (*types.T, error) { - name, desc, err := p.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(oid)) + id, err := typedesc.UserDefinedTypeOIDToID(oid) + if err != nil { + return nil, err + } + name, desc, err := p.GetTypeDescriptor(ctx, id) if err != nil { return nil, err } diff --git a/pkg/sql/sem/tree/casts.go b/pkg/sql/sem/tree/casts.go index 15011abd173a..63df8464be0e 100644 --- a/pkg/sql/sem/tree/casts.go +++ b/pkg/sql/sem/tree/casts.go @@ -1302,7 +1302,11 @@ func performIntToOidCast(ctx *EvalContext, t *types.T, v DInt) (Datum, error) { ret := &DOid{semanticType: t, DInt: v} if typ, ok := types.OidToType[oid.Oid(v)]; ok { ret.name = typ.PGName() - } else if typ, err := ctx.Planner.ResolveTypeByOID(ctx.Context, oid.Oid(v)); err == nil { + } else if types.IsOIDUserDefinedType(oid.Oid(v)) { + typ, err := ctx.Planner.ResolveTypeByOID(ctx.Context, oid.Oid(v)) + if err != nil { + return nil, err + } ret.name = typ.PGName() } return ret, nil diff --git a/pkg/sql/type_change.go b/pkg/sql/type_change.go index 569c96d99432..26750b1a7f7e 100644 --- a/pkg/sql/type_change.go +++ b/pkg/sql/type_change.go @@ -658,7 +658,10 @@ func findUsagesOfEnumValue( if !ok { return true, expr, nil } - id := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + id, err := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + if err != nil { + return false, expr, err + } if id != typeID { return true, expr, nil } @@ -680,8 +683,12 @@ func findUsagesOfEnumValue( if !ok { return true, expr, nil } + id, err := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + if err != nil { + return false, expr, err + } // -1 since the type of this CastExpr is the array type. - id := typedesc.UserDefinedTypeOIDToID(typeOid.OID) - 1 + id = id - 1 if id != typeID { return true, expr, nil } @@ -726,7 +733,10 @@ func findUsagesOfEnumValueInViewQuery( if !ok { return true, expr, nil } - id := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + id, err := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + if err != nil { + return false, expr, err + } if id != typeID { return true, expr, nil } @@ -816,22 +826,28 @@ func (t *typeSchemaChanger) canRemoveEnumValue( } } - if typeDesc.ID == typedesc.GetTypeDescID(col.GetType()) { - if !firstClause { - query.WriteString(" OR") + if col.GetType().UserDefined() { + tid, terr := typedesc.GetUserDefinedTypeDescID(col.GetType()) + if terr != nil { + return terr } - sqlPhysRep, err := convertToSQLStringRepresentation(member.PhysicalRepresentation) - if err != nil { - return err + if typeDesc.ID == tid { + if !firstClause { + query.WriteString(" OR") + } + sqlPhysRep, err := convertToSQLStringRepresentation(member.PhysicalRepresentation) + if err != nil { + return err + } + colName := col.ColName() + query.WriteString(fmt.Sprintf( + " t.%s = %s", + colName.String(), + sqlPhysRep, + )) + firstClause = false + validationQueryConstructed = true } - colName := col.ColName() - query.WriteString(fmt.Sprintf( - " t.%s = %s", - colName.String(), - sqlPhysRep, - )) - firstClause = false - validationQueryConstructed = true } } query.WriteString(" LIMIT 1") @@ -923,7 +939,14 @@ func (t *typeSchemaChanger) canRemoveEnumValueFromArrayUsages( // ) WHERE unnest = 'enum_value' firstClause := true for _, col := range desc.PublicColumns() { - if arrayTypeDesc.GetID() == typedesc.GetTypeDescID(col.GetType()) { + if !col.GetType().UserDefined() { + continue + } + tid, terr := typedesc.GetUserDefinedTypeDescID(col.GetType()) + if terr != nil { + return terr + } + if 
arrayTypeDesc.GetID() == tid { if !firstClause { unionUnnests.WriteString(" UNION ") } } diff --git a/pkg/util/span/frontier.go b/pkg/util/span/frontier.go index fee12d61ad29..fa251cfab974 100644 --- a/pkg/util/span/frontier.go +++ b/pkg/util/span/frontier.go @@ -25,11 +25,11 @@ import ( // frontierEntry represents a timestamped span. It is used as the nodes in both // the interval tree and heap needed to keep the Frontier. type frontierEntry struct { - id int64 - keys interval.Range - span roachpb.Span - ts hlc.Timestamp - + id int64 + keys interval.Range + span roachpb.Span + ts hlc.Timestamp + updateTS hlc.Timestamp // The index of the item in the frontierHeap, maintained by the // heap.Interface methods. index int @@ -104,6 +104,10 @@ type Frontier struct { minHeap frontierHeap idAlloc int64 + + // getUpdateTimestamp, if not nil, returns the hlc timestamp to record + // when a frontier entry is updated. + getUpdateTimestamp func() hlc.Timestamp } // makeSpan copies intervals start/end points and returns a span. @@ -137,6 +141,11 @@ func MakeFrontier(spans ...roachpb.Span) (*Frontier, error) { return f, nil } +// TrackUpdateTimestamp asks the frontier to keep track of span update timestamps. +func (f *Frontier) TrackUpdateTimestamp(now func() hlc.Timestamp) { + f.getUpdateTimestamp = now +} + // Frontier returns the minimum timestamp being tracked. func (f *Frontier) Frontier() hlc.Timestamp { if f.minHeap.Len() == 0 { @@ -222,19 +231,23 @@ func extendRangeToTheRight( } func (f *Frontier) insert(span roachpb.Span, insertTS hlc.Timestamp) error { - const continueMatch = false - // Set of frontier entries to add and remove. var toAdd, toRemove []*frontierEntry + var updateTS hlc.Timestamp + if f.getUpdateTimestamp != nil { + updateTS = f.getUpdateTimestamp() + } + // addEntry adds frontier entry to the toAdd list. addEntry := func(r interval.Range, ts hlc.Timestamp) { sp := makeSpan(r) toAdd = append(toAdd, &frontierEntry{ - id: f.idAlloc, - span: sp, - keys: sp.AsRange(), - ts: ts, + id: f.idAlloc, + span: sp, + keys: sp.AsRange(), + ts: ts, + updateTS: updateTS, }) f.idAlloc++ } @@ -289,10 +302,11 @@ func (f *Frontier) insert(span roachpb.Span, insertTS hlc.Timestamp) error { todoRange.Start = overlap.keys.Start } - // Fast case: we already recorded higher timestamp for this overlap. - if insertTS.LessEq(overlap.ts) { + // Fast case: we already recorded a higher timestamp for this overlap. + if insertTS.Less(overlap.ts) { + overlap.updateTS = updateTS todoRange.Start = overlap.keys.End - return continueMatch + return ContinueMatch.asBool() } // At this point, we know that overlap timestamp is not ahead of the insertTS @@ -333,7 +347,7 @@ func (f *Frontier) insert(span roachpb.Span, insertTS hlc.Timestamp) error { consumePrefix(overlap.keys.End) } - return continueMatch + return ContinueMatch.asBool() }, span.AsRange()) // Add remaining pending range. @@ -356,16 +370,87 @@ return nil } +// OpResult is the result of the Operation callback. +type OpResult bool + +const ( + // ContinueMatch signals DoMatching should continue. + ContinueMatch OpResult = false + // StopMatch signals DoMatching should stop. + StopMatch OpResult = true +) + +func (r OpResult) asBool() bool { + return bool(r) +} + +// An Operation is a function that operates on frontier spans. If it returns done=true, the +// Operation is indicating that no further work needs to be done and so the DoMatching function +// should traverse no further.
+type Operation func(roachpb.Span, hlc.Timestamp) (done OpResult) + // Entries invokes the given callback with the current timestamp for each // component span in the tracked span set. -func (f *Frontier) Entries(fn func(roachpb.Span, hlc.Timestamp)) { +func (f *Frontier) Entries(fn Operation) { f.tree.Do(func(i interval.Interface) bool { spe := i.(*frontierEntry) - fn(spe.span, spe.ts) - return false + return fn(spe.span, spe.ts).asBool() }) } + +// UpdatedEntries is similar to Entries, but invokes the provided fn only if + // the span update timestamp is newer than cutoff. + // This function requires TrackUpdateTimestamp to have been called on this Frontier. +func (f *Frontier) UpdatedEntries(cutoff hlc.Timestamp, fn Operation) { + f.tree.Do(func(i interval.Interface) bool { + e := i.(*frontierEntry) + if cutoff.Less(e.updateTS) { + return fn(e.span, e.ts).asBool() + } + return ContinueMatch.asBool() + }) +} + +// SpanEntries invokes op for each sub-span of the specified span with the +// timestamp as observed by this frontier. +// +// Time +// 5| .b__c . +// 4| . h__k . +// 3| . e__f . +// 1 ---a----------------------m---q-- Frontier +// |___________span___________| +// +// In the above example, frontier tracks [b, m) and the current frontier +// timestamp is 1. SpanEntries for span [a-q) will invoke op with: +// ([b-c), 5), ([c-e), 1), ([e-f), 3), ([f-h), 1), ([h-k), 4), ([k-m), 1). +// Note: neither [a-b) nor [m-q) will be emitted since they fall outside the spans +// tracked by this frontier. +func (f *Frontier) SpanEntries(span roachpb.Span, op Operation) { + todoRange := span.AsRange() + + f.tree.DoMatching(func(i interval.Interface) bool { + e := i.(*frontierEntry) + + // Skip untracked portion. + if todoRange.Start.Compare(e.keys.Start) < 0 { + todoRange.Start = e.keys.Start + } + + end := e.keys.End + if e.keys.End.Compare(todoRange.End) > 0 { + end = todoRange.End + } + + if op(roachpb.Span{Key: roachpb.Key(todoRange.Start), EndKey: roachpb.Key(end)}, e.ts) == StopMatch { + return StopMatch.asBool() + } + todoRange.Start = end + return ContinueMatch.asBool() + }, span.AsRange()) +} + +// String implements Stringer. func (f *Frontier) String() string { var buf strings.Builder f.tree.Do(func(i interval.Interface) bool { diff --git a/pkg/util/span/frontier_test.go b/pkg/util/span/frontier_test.go index 38acadb35931..b31f5d600de7 100644 --- a/pkg/util/span/frontier_test.go +++ b/pkg/util/span/frontier_test.go @@ -29,11 +29,12 @@ import ( func (f *Frontier) entriesStr() string { var buf strings.Builder - f.Entries(func(sp roachpb.Span, ts hlc.Timestamp) { + f.Entries(func(sp roachpb.Span, ts hlc.Timestamp) OpResult { if buf.Len() != 0 { buf.WriteString(` `) } fmt.Fprintf(&buf, `%s@%d`, sp, ts.WallTime) + return ContinueMatch }) return buf.String() } @@ -231,7 +232,7 @@ func TestSpanFrontierDisjointSpans(t *testing.T) { expectEntries(`{a-b}@3 {c-d}@3 {d-e}@2`) // Advance span that overlaps all the spans tracked by this frontier. - // {c-d} and {d-e} should collaps. + // {c-d} and {d-e} should collapse. forwardFrontier(roachpb.Span{Key: roachpb.Key(`0`), EndKey: roachpb.Key(`q`)}, 4). expectedAdvanced(true). expectFrontier(4).
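[Review note, not part of the patch: the Entries signature above now takes an Operation rather than a plain func(roachpb.Span, hlc.Timestamp), so a caller can short-circuit iteration. A minimal usage sketch under that assumption — it uses only pkg/roachpb, pkg/util/hlc, and pkg/util/span as they appear in this diff; the main function and key values are invented for illustration:

	package main

	import (
		"fmt"

		"github.com/cockroachdb/cockroach/pkg/roachpb"
		"github.com/cockroachdb/cockroach/pkg/util/hlc"
		"github.com/cockroachdb/cockroach/pkg/util/span"
	)

	func main() {
		// Track a single span [a, z), initially at the zero timestamp.
		f, err := span.MakeFrontier(roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("z")})
		if err != nil {
			panic(err)
		}
		// Advance only the prefix [a, m); the frontier splits into [a,m)@3 and [m,z)@0.
		if _, err := f.Forward(
			roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("m")},
			hlc.Timestamp{WallTime: 3},
		); err != nil {
			panic(err)
		}
		// Walk the entries and stop at the first sub-span still at zero;
		// returning StopMatch short-circuits the traversal.
		f.Entries(func(s roachpb.Span, ts hlc.Timestamp) span.OpResult {
			if ts.IsEmpty() {
				fmt.Printf("%s has not advanced yet\n", s)
				return span.StopMatch
			}
			return span.ContinueMatch
		})
	}

The early-exit path is exactly what changeAggregator.maybeFlush relies on above: it returns span.StopMatch to abort iteration on the first emitResolved error.]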
@@ -303,6 +304,126 @@ func TestSequentialSpans(t *testing.T) { require.Equal(t, strings.Join(expectedRanges, " "), f.entriesStr()) } +func TestSpanEntries(t *testing.T) { + defer leaktest.AfterTest(t)() + + key := func(c byte) roachpb.Key { + return roachpb.Key{c} + } + mkspan := func(start, end byte) roachpb.Span { + return roachpb.Span{Key: key(start), EndKey: key(end)} + } + + spAZ := mkspan('A', 'Z') + f, err := MakeFrontier(spAZ) + require.NoError(t, err) + + advance := func(s roachpb.Span, wall int64) { + _, err := f.Forward(s, hlc.Timestamp{WallTime: wall}) + require.NoError(t, err) + } + + spanEntries := func(sp roachpb.Span) string { + var buf strings.Builder + f.SpanEntries(sp, func(s roachpb.Span, ts hlc.Timestamp) OpResult { + if buf.Len() != 0 { + buf.WriteString(` `) + } + fmt.Fprintf(&buf, `%s@%d`, s, ts.WallTime) + return ContinueMatch + }) + return buf.String() + } + + // Nothing overlaps span fully to the left of frontier. + require.Equal(t, ``, spanEntries(mkspan('0', '9'))) + // Nothing overlaps span fully to the right of the frontier. + require.Equal(t, ``, spanEntries(mkspan('a', 'z'))) + + // Span overlaps entire frontier. + require.Equal(t, `{A-Z}@0`, spanEntries(spAZ)) + advance(spAZ, 1) + require.Equal(t, `{A-Z}@1`, spanEntries(spAZ)) + + // Span overlaps part of the frontier, with left part outside frontier. + require.Equal(t, `{A-C}@1`, spanEntries(mkspan('0', 'C'))) + + // Span overlaps part of the frontier, with right part outside frontier. + require.Equal(t, `{Q-Z}@1`, spanEntries(mkspan('Q', 'c'))) + + // Span fully inside frontier. + require.Equal(t, `{P-W}@1`, spanEntries(mkspan('P', 'W'))) + + // Advance part of the frontier. + advance(mkspan('C', 'E'), 2) + advance(mkspan('H', 'M'), 5) + advance(mkspan('N', 'Q'), 3) + + // Span overlaps various parts of the frontier. + require.Equal(t, + `{A-C}@1 {C-E}@2 {E-H}@1 {H-M}@5 {M-N}@1 {N-P}@3`, + spanEntries(mkspan('3', 'P'))) +} + +func TestUpdatedEntries(t *testing.T) { + defer leaktest.AfterTest(t)() + + key := func(c byte) roachpb.Key { + return roachpb.Key{c} + } + mkspan := func(start, end byte) roachpb.Span { + return roachpb.Span{Key: key(start), EndKey: key(end)} + } + + spAZ := mkspan('A', 'Z') + f, err := MakeFrontier(spAZ) + require.NoError(t, err) + + var wall int64 + advance := func(s roachpb.Span, newWall int64) { + wall = newWall + _, err := f.Forward(s, hlc.Timestamp{WallTime: wall}) + require.NoError(t, err) + } + + updatedEntries := func(cutoff int64) string { + var buf strings.Builder + f.UpdatedEntries(hlc.Timestamp{WallTime: cutoff}, func(s roachpb.Span, ts hlc.Timestamp) OpResult { + if buf.Len() != 0 { + buf.WriteString(` `) + } + fmt.Fprintf(&buf, `%s@%d`, s, ts.WallTime) + return ContinueMatch + }) + return buf.String() + } + + // If we haven't configured the frontier to keep track of updates, we expect to see + // no spans reported as updated.
+ require.Equal(t, ``, updatedEntries(0)) + require.Equal(t, ``, updatedEntries(1)) + advance(mkspan('C', 'E'), 2) + require.Equal(t, ``, updatedEntries(1)) + + f.TrackUpdateTimestamp(func() hlc.Timestamp { return hlc.Timestamp{WallTime: wall} }) + + advance(mkspan('C', 'E'), 3) + require.Equal(t, `{C-E}@3`, updatedEntries(0)) + require.Equal(t, `{C-E}@3`, updatedEntries(2)) + advance(mkspan('D', 'E'), 4) + require.Equal(t, `{C-D}@3 {D-E}@4`, updatedEntries(3)) + + // Nothing was updated after t=4 + require.Equal(t, ``, updatedEntries(4)) + + advance(mkspan('C', 'E'), 5) + require.Equal(t, `{C-E}@5`, updatedEntries(4)) + + advance(spAZ, 5) + require.Equal(t, `{A-Z}@5`, updatedEntries(4)) + require.Equal(t, ``, updatedEntries(5)) +} + // symbols that can make up spans. var spanSymbols = []byte("@$0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
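[Review note, not part of the patch: to tie the frontier changes back to their changefeed caller — TrackUpdateTimestamp stamps each forwarded entry with a clock reading, and UpdatedEntries then replays only the entries stamped after a cutoff. A self-contained sketch of the pattern changeAggregator.maybeFlush uses, assuming the cockroach packages as of this diff; the clock setup and keys are invented for illustration:

	package main

	import (
		"fmt"
		"time"

		"github.com/cockroachdb/cockroach/pkg/roachpb"
		"github.com/cockroachdb/cockroach/pkg/util/hlc"
		"github.com/cockroachdb/cockroach/pkg/util/span"
	)

	func main() {
		clock := hlc.NewClock(hlc.UnixNano, time.Nanosecond /* maxOffset */)

		f, err := span.MakeFrontier(roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("z")})
		if err != nil {
			panic(err)
		}
		// Stamp every forwarded entry with the clock reading at update time,
		// mirroring makeSchemaChangeFrontier above.
		f.TrackUpdateTimestamp(func() hlc.Timestamp { return clock.Now() })

		lastFlush := clock.Now()

		// A resolved span arrives and advances part of the frontier.
		if _, err := f.Forward(
			roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("m")},
			clock.Now(),
		); err != nil {
			panic(err)
		}

		// Replay only the entries stamped after lastFlush, as maybeFlush does.
		f.UpdatedEntries(lastFlush, func(s roachpb.Span, ts hlc.Timestamp) span.OpResult {
			fmt.Printf("would emit resolved %s@%s\n", s, ts)
			return span.ContinueMatch
		})
	}

This is why maybeFlush no longer needs the spansToFlush slice: the frontier itself remembers which entries changed since lastFlush.]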