Skip to content

Commit

Permalink
Merge pull request #92 from planetscale/fix-enum-sets
Browse files Browse the repository at this point in the history
Fix enum & set values after COPY phase
  • Loading branch information
notfelineit authored Dec 14, 2023
2 parents 5baac4f + 36ee11f commit 54a459f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 3 deletions.
91 changes: 89 additions & 2 deletions cmd/internal/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package internal

import (
"encoding/base64"
"regexp"
"strconv"
"strings"

"github.com/pkg/errors"
psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1"
"github.com/planetscale/psdb/core/codec"
Expand Down Expand Up @@ -129,7 +133,6 @@ func TableCursorToSerializedCursor(cursor *psdbconnect.TableCursor) (*Serialized

func QueryResultToRecords(qr *sqltypes.Result) []map[string]interface{} {
data := make([]map[string]interface{}, 0, len(qr.Rows))

columns := make([]string, 0, len(qr.Fields))
for _, field := range qr.Fields {
columns = append(columns, field.Name)
Expand All @@ -139,7 +142,7 @@ func QueryResultToRecords(qr *sqltypes.Result) []map[string]interface{} {
record := make(map[string]interface{})
for idx, val := range row {
if idx < len(columns) {
record[columns[idx]] = val
record[columns[idx]] = parseValue(val, qr.Fields[idx].GetColumnType())
}
}
data = append(data, record)
Expand All @@ -148,6 +151,90 @@ func QueryResultToRecords(qr *sqltypes.Result) []map[string]interface{} {
return data
}

// parseValue normalizes a single column value according to the column's declared type.
// Once the initial COPY phase is over, enum and set values may show up as an index
// rather than the member itself, e.g. "1" instead of "apple" for an
// enum('apple','banana','orange') column. Such values are handed to the enum/set
// mappers; values of every other column type are returned untouched.
func parseValue(val sqltypes.Value, columnType string) sqltypes.Value {
	switch {
	case strings.HasPrefix(columnType, "enum"):
		return mapEnumValue(val, parseEnumOrSetValues(columnType))
	case strings.HasPrefix(columnType, "set"):
		return mapSetValue(val, parseEnumOrSetValues(columnType))
	default:
		return val
	}
}

// Compiled once at package scope: parseEnumOrSetValues runs for every enum/set
// cell of every row, so compiling inside the function would be wasted work.
var (
	// enumOrSetListRE captures everything between the first "(" and the last ")".
	enumOrSetListRE = regexp.MustCompile(`\((.+)\)`)
	// enumOrSetMemberRE matches one single-quoted member; a quote inside a
	// member is expected to be written as two consecutive quotes.
	enumOrSetMemberRE = regexp.MustCompile(`'((?:[^']|'')*)'`)
)

// parseEnumOrSetValues takes an enum or set column type such as
// enum('a','b','c') or set('a','b','c') and returns its members in
// declaration order, e.g. []string{"a", "b", "c"}.
//
// Members are read as single-quoted tokens, so a comma inside a member
// ('a,b') does not split it and a doubled quote ('it''s') is unescaped to a
// single quote. If the list contains no quoted token at all, it is split on
// commas as a fallback. A column type without a parenthesised list yields an
// empty slice, so callers find no member to map to and keep the original value.
func parseEnumOrSetValues(columnType string) []string {
	values := []string{}

	list := enumOrSetListRE.FindStringSubmatch(columnType)
	if list == nil {
		return values
	}
	body := list[1]

	members := enumOrSetMemberRE.FindAllStringSubmatch(body, -1)
	if len(members) == 0 {
		// No quoted members found: fall back to a plain comma split.
		for _, raw := range strings.Split(body, ",") {
			values = append(values, strings.Trim(raw, "'"))
		}
		return values
	}

	for _, m := range members {
		values = append(values, strings.ReplaceAll(m[1], "''", "'"))
	}
	return values
}

// mapSetValue translates a SET column value that arrived as a numeric bitmap
// into the comma-separated list of member names it stands for.
//
// value is the raw column value and values holds the SET members in
// declaration order. Bit i of the bitmap, counting from the least significant
// bit, selects values[i]; with five members the bitmap 24 (binary 11000)
// selects the fourth and fifth member. Selected members are joined in
// declaration order.
//
// The original value is returned unchanged when it is not a non-negative
// base-10 integer, when it has a bit set beyond the known members, when no bit
// is set, or when the mapped value cannot be constructed.
func mapSetValue(value sqltypes.Value, values []string) sqltypes.Value {
	bitmap, err := strconv.ParseUint(value.ToString(), 10, 64)
	if err != nil {
		// Not a bitmap (most likely already the member names): keep it as is.
		return value
	}

	// A bit above the last known member has nothing to map to. Indexing with
	// it would run past the end of values, so leave the value alone instead.
	if bitmap>>uint(len(values)) != 0 {
		return value
	}

	mappedValues := make([]string, 0, len(values))
	for i, member := range values {
		if bitmap&(uint64(1)<<uint(i)) != 0 {
			mappedValues = append(mappedValues, member)
		}
	}

	// If we can't find any member, just return the original value.
	// NOTE(review): a bitmap of 0 presumably denotes the empty set; the raw
	// value is still returned here, as before. Confirm whether an empty
	// string should be emitted instead.
	if len(mappedValues) == 0 {
		return value
	}

	mappedValue, err := sqltypes.NewValue(value.Type(), []byte(strings.Join(mappedValues, ",")))
	if err != nil {
		// Do not replace real data with a zero Value when construction fails.
		return value
	}
	return mappedValue
}

// mapEnumValue translates an ENUM column value that arrived as a numeric index
// into the member name it stands for.
//
// value is the raw column value and values holds the ENUM members in
// declaration order. Indexes are 1-based: index 1 maps to values[0]. Index 0
// is the empty-string error value and maps to "".
//
// The original value is returned unchanged when it is not a base-10 integer,
// when the index is outside the known members, or when the mapped value cannot
// be constructed.
func mapEnumValue(value sqltypes.Value, values []string) sqltypes.Value {
	index, err := strconv.ParseInt(value.ToString(), 10, 64)
	if err != nil {
		// If value is not an integer (index), we just return the original value.
		return value
	}

	var mapped string
	switch {
	case index == 0:
		// The index value of the empty string error value is 0.
		mapped = ""
	case index >= 1 && index <= int64(len(values)):
		mapped = values[index-1]
	default:
		// Just return the original value if we can't find the enum value.
		return value
	}

	mappedValue, err := sqltypes.NewValue(value.Type(), []byte(mapped))
	if err != nil {
		// Do not replace real data with a zero Value when construction fails.
		return value
	}
	return mappedValue
}

// AirbyteState wraps a SyncState, which is (de)serialized under the "data" JSON key.
type AirbyteState struct {
Data SyncState `json:"data"`
}
Expand Down
34 changes: 33 additions & 1 deletion cmd/internal/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package internal

import (
"encoding/base64"
"testing"

psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1"
"github.com/planetscale/psdb/core/codec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/proto/query"
)
Expand Down Expand Up @@ -76,3 +77,34 @@ func TestCanUnmarshalLastKnownState(t *testing.T) {
assert.Equal(t, "THIS_IS_A_GTID", tc.Position)
assert.Equal(t, lastKnownPK, tc.LastKnownPk)
}

// TestCanMapEnumAndSetValues checks that QueryResultToRecords maps enum and
// set values given as an index/bitmap to their member names, and leaves values
// that already hold member names untouched.
func TestCanMapEnumAndSetValues(t *testing.T) {
	// Setup failures abort the test right away: the assertions below would be
	// meaningless (or panic) without valid inputs.
	indexEnumValue, err := sqltypes.NewValue(query.Type_ENUM, []byte("1"))
	require.NoError(t, err)
	indexSetValue, err := sqltypes.NewValue(query.Type_SET, []byte("24")) // 24 is decimal conversion of 11000 in binary
	require.NoError(t, err)
	mappedEnumValue, err := sqltypes.NewValue(query.Type_ENUM, []byte("active"))
	require.NoError(t, err)
	mappedSetValue, err := sqltypes.NewValue(query.Type_SET, []byte("San Francisco,Oakland"))
	require.NoError(t, err)

	input := sqltypes.Result{
		Fields: []*query.Field{
			{Name: "customer_id", Type: sqltypes.Int64, ColumnType: "bigint"},
			{Name: "status", Type: sqltypes.Enum, ColumnType: "enum('active','inactive')"},
			{Name: "locations", Type: sqltypes.Set, ColumnType: "set('San Francisco','New York','London','San Jose','Oakland')"},
		},
		Rows: [][]sqltypes.Value{
			{sqltypes.NewInt64(1), indexEnumValue, indexSetValue},
			{sqltypes.NewInt64(2), mappedEnumValue, mappedSetValue},
		},
	}

	output := QueryResultToRecords(&input)
	// require, not assert: indexing output below would panic on a short result.
	require.Len(t, output, 2)

	firstRow := output[0]
	assert.Equal(t, "active", firstRow["status"].(sqltypes.Value).ToString())
	assert.Equal(t, "San Jose,Oakland", firstRow["locations"].(sqltypes.Value).ToString())

	secondRow := output[1]
	assert.Equal(t, "active", secondRow["status"].(sqltypes.Value).ToString())
	assert.Equal(t, "San Francisco,Oakland", secondRow["locations"].(sqltypes.Value).ToString())
}

0 comments on commit 54a459f

Please sign in to comment.