Skip to content

Commit

Permalink
*: Restrict invalid vectors
Browse files Browse the repository at this point in the history
Signed-off-by: Wish <[email protected]>
  • Loading branch information
breezewish committed Sep 25, 2024
1 parent a5e07a2 commit 59d31e7
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 1 deletion.
17 changes: 17 additions & 0 deletions pkg/ddl/add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,17 @@ func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *
col.DefaultIsExpr = true
return str, false, nil

case ast.VecFromText:
if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil {
return nil, false, errors.Trace(err)
}
str, err := restoreFuncCall(expr)
if err != nil {
return nil, false, errors.Trace(err)
}
col.DefaultIsExpr = true
return str, false, nil

default:
return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String())
}
Expand Down Expand Up @@ -1174,6 +1185,12 @@ func checkColumnValueConstraint(col *table.Column, collation string) error {
// In NO_ZERO_DATE SQL mode, TIMESTAMP/DATE/DATETIME type can't have zero date like '0000-00-00' or '0000-00-00 00:00:00'.
func checkColumnDefaultValue(ctx exprctx.BuildContext, col *table.Column, value any) (bool, any, error) {
hasDefaultValue := true

if value != nil && col.GetType() == mysql.TypeTiDBVectorFloat32 {
// In any SQL mode we don't allow VECTOR column to have a default value.
// Note that expression default is still supported.
return hasDefaultValue, value, errors.Errorf("VECTOR column '%-.192s' can't have a default value", col.Name.O)
}
if value != nil && (col.GetType() == mysql.TypeJSON ||
col.GetType() == mysql.TypeTinyBlob || col.GetType() == mysql.TypeMediumBlob ||
col.GetType() == mysql.TypeLongBlob || col.GetType() == mysql.TypeBlob) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"main_test.go",
],
flaky = True,
shard_count = 43,
shard_count = 44,
deps = [
"//pkg/config",
"//pkg/domain",
Expand Down
209 changes: 209 additions & 0 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,215 @@ import (
"github.com/tikv/client-go/v2/oracle"
)

// TestVectorDefaultValue drives a single-column table `t(embedding VECTOR...)`
// through each combination of nullability, DEFAULT clause and SQL mode, and
// checks which INSERT statements are accepted and what ends up stored.
//
// The statements are order-dependent: every scenario creates `t`, inserts,
// verifies the stored rows together with vec_dims(embedding), and drops `t`
// again before the next scenario starts.
func TestVectorDefaultValue(t *testing.T) {
	store := testkit.CreateMockStore(t)
	tk := testkit.NewTestKit(t, store)

	// Local shorthands so each scenario reads as a compact script.
	run := tk.MustExec              // statement issued via MustExec
	reject := tk.MustExecToErr      // statement issued via MustExecToErr
	rejectWith := tk.MustGetErrMsg  // statement issued via MustGetErrMsg with the expected message
	// expectRows compares the table content (value plus its dimension count)
	// against the given rows.
	expectRows := func(want ...string) {
		tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows(want...))
	}
	dropTable := func() {
		run("drop table t")
	}

	run("use test")

	// ============================
	// NULLABLE, NO DEFAULT: first in Non-Strict Mode, then in Strict Mode.
	// Both modes are expected to behave the same.
	for _, mode := range []string{"", "STRICT_ALL_TABLES"} {
		run("set @@session.sql_mode='" + mode + "'")

		// Dimension-less VECTOR: any dimension is stored, DEFAULT and NULL both yield NULL.
		run("create table t(embedding VECTOR)")
		run("insert into t values ('[1,2,3]')")
		run("insert into t values ('[4]')")
		run("insert into t values (DEFAULT)")
		run("insert into t values (NULL)")
		expectRows(
			"[1,2,3] 3",
			"[4] 1",
			"<nil> <nil>",
			"<nil> <nil>",
		)
		dropTable()

		// VECTOR(3): a vector of another dimension is refused.
		run("create table t(embedding VECTOR(3))")
		run("insert into t values ('[1,2,3]')")
		reject("insert into t values ('[4]')")
		run("insert into t values (DEFAULT)")
		run("insert into t values (NULL)")
		expectRows(
			"[1,2,3] 3",
			"<nil> <nil>",
			"<nil> <nil>",
		)
		dropTable()
	}

	// ============================
	// NULLABLE, DEFAULT

	// A literal default is refused at CREATE TABLE time; an expression default is accepted.
	rejectWith("create table t(embedding VECTOR DEFAULT '[1,2,3]')", `VECTOR column 'embedding' can't have a default value`)
	run("create table t(embedding VECTOR DEFAULT (VEC_FROM_TEXT('[1,2,3]')))")
	run("insert into t values ()")
	run("insert into t values ('[4]')")
	run("insert into t values (DEFAULT)")
	run("insert into t values (NULL)")
	run("insert into t values (DEFAULT(embedding))")
	expectRows(
		"[1,2,3] 3",
		"[4] 1",
		"[1,2,3] 3",
		"<nil> <nil>",
		"[1,2,3] 3",
	)
	dropTable()

	// Allowed at DDL time even though the default has the wrong dimension;
	// the error only shows up when the default is actually used by an insert.
	run("create table t(embedding VECTOR(5) DEFAULT (VEC_FROM_TEXT('[1,2,3]')))")
	rejectWith("insert into t values ()", "vector has 3 dimensions, does not fit VECTOR(5)")
	rejectWith("insert into t values (DEFAULT)", "vector has 3 dimensions, does not fit VECTOR(5)")
	run("insert into t values ('[1,2,3,4,5]')")
	run("insert into t values (NULL)")
	expectRows("[1,2,3,4,5] 5", "<nil> <nil>")
	dropTable()

	// Allowed at DDL time even though UUID() does not produce vector text;
	// again the error only shows up when inserting.
	run("create table t(embedding VECTOR(5) DEFAULT (UUID()))")
	tk.MustContainErrMsg("insert into t values ()", "Invalid vector text: ")
	run("insert into t values ('[1,2,3,4,5]')")
	run("insert into t values (NULL)")
	expectRows("[1,2,3,4,5] 5", "<nil> <nil>")
	dropTable()

	// ============================
	// NOT NULL, NO DEFAULT, Non-Strict Mode

	run("set @@session.sql_mode=''")
	// Without a fixed dimension, DEFAULT is stored as the empty vector.
	run("create table t(embedding VECTOR NOT NULL)")
	run("insert into t values ('[1,2,3]')")
	run("insert into t values ('[4]')")
	run("insert into t values (DEFAULT)")
	reject("insert into t values (NULL)")
	expectRows(
		"[1,2,3] 3",
		"[4] 1",
		"[] 0",
	)
	dropTable()

	// With a fixed dimension, DEFAULT is refused as well.
	run("create table t(embedding VECTOR(3) NOT NULL)")
	run("insert into t values ('[1,2,3]')")
	reject("insert into t values ('[4]')")
	reject("insert into t values (DEFAULT)")
	reject("insert into t values (NULL)")
	expectRows(
		"[1,2,3] 3",
	)
	dropTable()

	// ============================
	// NOT NULL, NO DEFAULT, Strict Mode

	run("set @@session.sql_mode='STRICT_ALL_TABLES'")
	// Here DEFAULT is refused even for the dimension-less column.
	run("create table t(embedding VECTOR NOT NULL)")
	run("insert into t values ('[1,2,3]')")
	run("insert into t values ('[4]')")
	reject("insert into t values (DEFAULT)")
	reject("insert into t values (NULL)")
	expectRows(
		"[1,2,3] 3",
		"[4] 1",
	)
	dropTable()

	run("create table t(embedding VECTOR(3) NOT NULL)")
	run("insert into t values ('[1,2,3]')")
	reject("insert into t values ('[4]')")
	reject("insert into t values (DEFAULT)")
	reject("insert into t values (NULL)")
	expectRows(
		"[1,2,3] 3",
	)
	dropTable()

	// ============================
	// NOT NULL, DEFAULT: first in Non-Strict Mode, then in Strict Mode.
	// Both modes are expected to behave the same.
	for _, mode := range []string{"", "STRICT_ALL_TABLES"} {
		run("set @@session.sql_mode='" + mode + "'")

		// The expression default fits the column, so every form of DEFAULT works.
		run("create table t(embedding VECTOR NOT NULL DEFAULT (VEC_FROM_TEXT('[1,2,3]')))")
		run("insert into t values ()")
		run("insert into t values ('[4]')")
		run("insert into t values (DEFAULT)")
		reject("insert into t values (NULL)")
		run("insert into t values (DEFAULT(embedding))")
		expectRows(
			"[1,2,3] 3",
			"[4] 1",
			"[1,2,3] 3",
			"[1,2,3] 3",
		)
		dropTable()

		// The expression default has 3 dimensions but the column is VECTOR(1),
		// so every form of DEFAULT is refused.
		run("create table t(embedding VECTOR(1) NOT NULL DEFAULT (VEC_FROM_TEXT('[1,2,3]')))")
		reject("insert into t values ()")
		run("insert into t values ('[4]')")
		reject("insert into t values (DEFAULT)")
		reject("insert into t values (NULL)")
		reject("insert into t values (DEFAULT(embedding))")
		expectRows(
			"[4] 1",
		)
		dropTable()
	}
}

func TestVectorColumnInfo(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
5 changes: 5 additions & 0 deletions pkg/table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,11 @@ func CheckOnce(cols []*Column) error {
// Otherwise, it will return a ErrColumnCantNull when error.
func (c *Column) CheckNotNull(data *types.Datum, rowCntInLoadData uint64) error {
if (mysql.HasNotNullFlag(c.GetFlag()) || mysql.HasPreventNullInsertFlag(c.GetFlag())) && data.IsNull() {
if c.FieldType.EvalType().IsVectorKind() {
// Vector(N) is a special case because it does not have zero values as a fallback.
// So we always reject NULLs in NotNull context even if SQL mode is not strict.
return errors.Errorf("VECTOR column '%s' cannot be null", c.Name)
}
if rowCntInLoadData > 0 {
return ErrWarnNullToNotnull.GenWithStackByArgs(c.Name, rowCntInLoadData)
}
Expand Down

0 comments on commit 59d31e7

Please sign in to comment.