Skip to content

Commit

Permalink
sanitize column name
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyangyu committed May 7, 2022
1 parent ed9cbe7 commit 36015f3
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 38 deletions.
72 changes: 63 additions & 9 deletions cdc/sink/codec/avro.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/big"
"strconv"
"strings"
Expand Down Expand Up @@ -181,11 +182,11 @@ func (a *AvroEventBatchEncoder) avroEncode(
}
}

var fqdn string = e.Table.Schema + "." + e.Table.Table
qualifiedName := getQualifiedNameFromTableName(e.Table)

schemaGen := func() (string, error) {
schema, err := rowToAvroSchema(
fqdn,
qualifiedName,
cols,
colInfos,
enableTiDBExtension,
Expand All @@ -200,7 +201,7 @@ func (a *AvroEventBatchEncoder) avroEncode(

avroCodec, registryID, err := schemaManager.GetCachedOrRegister(
ctx,
fqdn,
qualifiedName,
e.TableInfoVersion,
schemaGen,
)
Expand Down Expand Up @@ -289,6 +290,59 @@ func getTiDBTypeFromColumn(col *model.Column) string {
return tt
}

const (
replacementChar = "_"
numberPrefix = "_"
)

// debezium-core/src/main/java/io/debezium/schema/FieldNameSelector.java
// https://avro.apache.org/docs/current/spec.html#names
func sanitizeColumnName(name string) string {
changed := false
var sb strings.Builder
for i, c := range name {
if i == 0 && (c >= '0' && c <= '9') {
sb.WriteString(numberPrefix)
sb.WriteRune(c)
changed = true
} else if !(c == '_' ||
('a' <= c && c <= 'z') ||
('A' <= c && c <= 'Z') ||
('0' <= c && c <= '9')) {
sb.WriteString(replacementChar)
changed = true
} else {
sb.WriteRune(c)
}
}

sanitizedName := sb.String()
if changed {
log.Warn(
fmt.Sprintf(
"Field '%s' name potentially not safe for serialization, replaced with '%s'",
name,
sanitizedName,
),
)
}
return sanitizedName
}

// https://github.com/debezium/debezium/blob/9f7ede0e0695f012c6c4e715e96aed85eecf6b5f \
// /debezium-connector-mysql/src/main/java/io/debezium/connector/mysql/antlr/ \
// MySqlAntlrDdlParser.java#L374
func escapeEnumAndSetOptions(option string) string {
option = strings.ReplaceAll(option, ",", "\\,")
option = strings.ReplaceAll(option, "\\'", "'")
option = strings.ReplaceAll(option, "''", "'")
return option
}

func getQualifiedNameFromTableName(tableName *model.TableName) string {
return tableName.Schema + "." + tableName.Table
}

type avroSchema struct {
Type string `json:"type"`
Parameters map[string]string `json:"connect.parameters"`
Expand All @@ -302,7 +356,7 @@ type avroLogicalTypeSchema struct {
}

func rowToAvroSchema(
fqdn string,
qualifiedName string,
columnInfo []*model.Column,
colInfos []rowcodec.ColInfo,
enableTiDBExtension bool,
Expand All @@ -311,7 +365,7 @@ func rowToAvroSchema(
) (string, error) {
top := avroSchemaTop{
Tp: "record",
Name: fqdn,
Name: qualifiedName,
Fields: nil,
}

Expand All @@ -326,7 +380,7 @@ func rowToAvroSchema(
return "", err
}
field := make(map[string]interface{})
field["name"] = col.Name
field["name"] = sanitizeColumnName(col.Name)
if col.Flag.IsNullable() {
field["type"] = []interface{}{"null", avroType}
field["default"] = nil
Expand Down Expand Up @@ -388,9 +442,9 @@ func rowToAvroData(

// https://pkg.go.dev/github.com/linkedin/goavro/v2#Union
if col.Flag.IsNullable() {
ret[col.Name] = goavro.Union(str, data)
ret[sanitizeColumnName(col.Name)] = goavro.Union(str, data)
} else {
ret[col.Name] = data
ret[sanitizeColumnName(col.Name)] = data
}
}

Expand Down Expand Up @@ -503,7 +557,7 @@ func columnToAvroSchema(
case mysql.TypeEnum, mysql.TypeSet:
es := make([]string, 0, len(ft.Elems))
for _, e := range ft.Elems {
e = strings.ReplaceAll(e, ",", "\\,")
e = escapeEnumAndSetOptions(e)
es = append(es, e)
}
return avroSchema{
Expand Down
19 changes: 16 additions & 3 deletions cdc/sink/codec/avro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ func TestRowToAvroSchema(t *testing.T) {
Schema: "testdb",
Table: "rowtoavroschema",
}
fqdn := table.Schema + "." + table.Table
qualifiedName := getQualifiedNameFromTableName(&table)
var cols []*model.Column = make([]*model.Column, 0)
var colInfos []rowcodec.ColInfo = make([]rowcodec.ColInfo, 0)

Expand All @@ -648,13 +648,13 @@ func TestRowToAvroSchema(t *testing.T) {
colInfos = append(colInfos, v.colInfo)
}

schema, err := rowToAvroSchema(fqdn, cols, colInfos, false, "precise", "long")
schema, err := rowToAvroSchema(qualifiedName, cols, colInfos, false, "precise", "long")
require.NoError(t, err)
require.Equal(t, expectedSchemaWithoutExtension, indentJSON(schema))
_, err = goavro.NewCodec(schema)
require.NoError(t, err)

schema, err = rowToAvroSchema(fqdn, cols, colInfos, true, "precise", "long")
schema, err = rowToAvroSchema(qualifiedName, cols, colInfos, true, "precise", "long")
require.NoError(t, err)
require.Equal(t, expectedSchemaWithExtension, indentJSON(schema))
_, err = goavro.NewCodec(schema)
Expand Down Expand Up @@ -833,3 +833,16 @@ func TestAvroEnvelope(t *testing.T) {
require.True(t, exists)
require.Equal(t, int32(7), id)
}

func TestSanitizeColumnName(t *testing.T) {
t.Parallel()

require.Equal(t, "normalColumnName123", sanitizeColumnName("normalColumnName123"))
require.Equal(
t,
"_1ColumnNameStartWithNumber",
sanitizeColumnName("1ColumnNameStartWithNumber"),
)
require.Equal(t, "A_B", sanitizeColumnName("A.B"))
require.Equal(t, "columnNameWith__", sanitizeColumnName("columnNameWith中文"))
}
22 changes: 11 additions & 11 deletions cdc/sink/codec/schema_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ var regexRemoveSpaces = regexp.MustCompile(`\s`)
// Register a schema in schema registry, no cache
func (m *AvroSchemaManager) Register(
ctx context.Context,
fqdn string,
qualifiedName string,
codec *goavro.Codec,
) (int, error) {
// The Schema Registry expects the JSON to be without newline characters
Expand All @@ -143,7 +143,7 @@ func (m *AvroSchemaManager) Register(
)
}
uri := m.registryURL + "/subjects/" + url.QueryEscape(
m.tableNameToSchemaSubject(fqdn),
m.tableNameToSchemaSubject(qualifiedName),
) + "/versions"
log.Debug("Registering schema", zap.String("uri", uri), zap.ByteString("payload", payload))

Expand Down Expand Up @@ -217,10 +217,10 @@ func (m *AvroSchemaManager) Register(
// NOT USED for now, reserved for future use.
func (m *AvroSchemaManager) Lookup(
ctx context.Context,
fqdn string,
qualifiedName string,
tiSchemaID uint64,
) (*goavro.Codec, int, error) {
key := m.tableNameToSchemaSubject(fqdn)
key := m.tableNameToSchemaSubject(qualifiedName)
m.cacheRWLock.RLock()
if entry, exists := m.cache[key]; exists && entry.tiSchemaID == tiSchemaID {
log.Info("Avro schema lookup cache hit",
Expand Down Expand Up @@ -336,11 +336,11 @@ type SchemaGenerator func() (string, error)
// cache is out-of-sync with schema registry, we could reload it.
func (m *AvroSchemaManager) GetCachedOrRegister(
ctx context.Context,
fqdn string,
qualifiedName string,
tiSchemaID uint64,
schemaGen SchemaGenerator,
) (*goavro.Codec, int, error) {
key := m.tableNameToSchemaSubject(fqdn)
key := m.tableNameToSchemaSubject(qualifiedName)
m.cacheRWLock.RLock()
if entry, exists := m.cache[key]; exists && entry.tiSchemaID == tiSchemaID {
log.Debug("Avro schema GetCachedOrRegister cache hit",
Expand Down Expand Up @@ -372,7 +372,7 @@ func (m *AvroSchemaManager) GetCachedOrRegister(
)
}

id, err := m.Register(ctx, fqdn, codec)
id, err := m.Register(ctx, qualifiedName, codec)
if err != nil {
return nil, 0, errors.Annotate(
cerror.WrapError(
Expand Down Expand Up @@ -403,8 +403,8 @@ func (m *AvroSchemaManager) GetCachedOrRegister(
// ClearRegistry clears the Registry subject for the given table. Should be idempotent.
// Exported for testing.
// NOT USED for now, reserved for future use.
func (m *AvroSchemaManager) ClearRegistry(ctx context.Context, fqdn string) error {
uri := m.registryURL + "/subjects/" + url.QueryEscape(m.tableNameToSchemaSubject(fqdn))
func (m *AvroSchemaManager) ClearRegistry(ctx context.Context, qualifiedName string) error {
uri := m.registryURL + "/subjects/" + url.QueryEscape(m.tableNameToSchemaSubject(qualifiedName))
req, err := http.NewRequestWithContext(ctx, "DELETE", uri, nil)
if err != nil {
log.Error("Could not construct request for clearRegistry", zap.String("uri", uri))
Expand Down Expand Up @@ -498,9 +498,9 @@ func httpRetry(
return resp, nil
}

func (m *AvroSchemaManager) tableNameToSchemaSubject(fqdn string) string {
func (m *AvroSchemaManager) tableNameToSchemaSubject(qualifiedName string) string {
// obey the RecordNameStrategy but generate a global unique subject
// https://docs.confluent.io/platform/current/schema-registry/serdes-develop/index.html \
// #subject-name-strategy
return fqdn + m.subjectSuffix
return qualifiedName + m.subjectSuffix
}
30 changes: 15 additions & 15 deletions cdc/sink/codec/schema_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ func TestSchemaRegistry(t *testing.T) {
)
require.NoError(t, err)

fqdn := table.Schema + "." + table.Table
qualifiedName := getQualifiedNameFromTableName(&table)

err = manager.ClearRegistry(getTestingContext(), fqdn)
err = manager.ClearRegistry(getTestingContext(), qualifiedName)
require.NoError(t, err)

_, _, err = manager.Lookup(getTestingContext(), fqdn, 1)
_, _, err = manager.Lookup(getTestingContext(), qualifiedName, 1)
require.Regexp(t, `.*not\sfound.*`, err)

codec, err := goavro.NewCodec(`{
Expand All @@ -198,12 +198,12 @@ func TestSchemaRegistry(t *testing.T) {
}`)
require.NoError(t, err)

_, err = manager.Register(getTestingContext(), fqdn, codec)
_, err = manager.Register(getTestingContext(), qualifiedName, codec)
require.NoError(t, err)

var id int
for i := 0; i < 2; i++ {
_, id, err = manager.Lookup(getTestingContext(), fqdn, 1)
_, id, err = manager.Lookup(getTestingContext(), qualifiedName, 1)
require.NoError(t, err)
require.Greater(t, id, 0)
}
Expand All @@ -228,10 +228,10 @@ func TestSchemaRegistry(t *testing.T) {
]
}`)
require.NoError(t, err)
_, err = manager.Register(getTestingContext(), fqdn, codec)
_, err = manager.Register(getTestingContext(), qualifiedName, codec)
require.NoError(t, err)

codec2, id2, err := manager.Lookup(getTestingContext(), fqdn, 999)
codec2, id2, err := manager.Lookup(getTestingContext(), qualifiedName, 999)
require.NoError(t, err)
require.NotEqual(t, id, id2)
require.Equal(t, codec.CanonicalSchema(), codec2.CanonicalSchema())
Expand All @@ -255,7 +255,7 @@ func TestSchemaRegistryIdempotent(t *testing.T) {
Schema: "testdb",
Table: "test",
}
fqdn := table.Schema + "." + table.Table
qualifiedName := getQualifiedNameFromTableName(&table)

manager, err := NewAvroSchemaManager(
getTestingContext(),
Expand All @@ -265,7 +265,7 @@ func TestSchemaRegistryIdempotent(t *testing.T) {
)
require.NoError(t, err)
for i := 0; i < 20; i++ {
err = manager.ClearRegistry(getTestingContext(), fqdn)
err = manager.ClearRegistry(getTestingContext(), qualifiedName)
require.NoError(t, err)
}

Expand All @@ -292,7 +292,7 @@ func TestSchemaRegistryIdempotent(t *testing.T) {

id := 0
for i := 0; i < 20; i++ {
id1, err := manager.Register(getTestingContext(), fqdn, codec)
id1, err := manager.Register(getTestingContext(), qualifiedName, codec)
require.NoError(t, err)
require.True(t, id == 0 || id == id1)
id = id1
Expand Down Expand Up @@ -341,20 +341,20 @@ func TestGetCachedOrRegister(t *testing.T) {
]
}`, nil
}
fqdn := table.Schema + "." + table.Table
qualifiedName := getQualifiedNameFromTableName(&table)

codec, id, err := manager.GetCachedOrRegister(getTestingContext(), fqdn, 1, schemaGen)
codec, id, err := manager.GetCachedOrRegister(getTestingContext(), qualifiedName, 1, schemaGen)
require.NoError(t, err)
require.Greater(t, id, 0)
require.NotNil(t, codec)
require.Equal(t, 1, called)

codec1, _, err := manager.GetCachedOrRegister(getTestingContext(), fqdn, 1, schemaGen)
codec1, _, err := manager.GetCachedOrRegister(getTestingContext(), qualifiedName, 1, schemaGen)
require.NoError(t, err)
require.True(t, codec == codec1) // check identity
require.Equal(t, 1, called)

codec2, _, err := manager.GetCachedOrRegister(getTestingContext(), fqdn, 2, schemaGen)
codec2, _, err := manager.GetCachedOrRegister(getTestingContext(), qualifiedName, 2, schemaGen)
require.NoError(t, err)
require.NotEqual(t, codec, codec2)
require.Equal(t, 2, called)
Expand Down Expand Up @@ -390,7 +390,7 @@ func TestGetCachedOrRegister(t *testing.T) {
for j := 0; j < 100; j++ {
codec, id, err := manager.GetCachedOrRegister(
getTestingContext(),
fqdn,
qualifiedName,
uint64(finalI),
schemaGen,
)
Expand Down

0 comments on commit 36015f3

Please sign in to comment.