From 59d31e7a9f6e781e005dd0db519ec5ab8c91f7b3 Mon Sep 17 00:00:00 2001 From: Wish Date: Wed, 25 Sep 2024 14:42:47 +0800 Subject: [PATCH] *: Restrict invalid vectors Signed-off-by: Wish --- pkg/ddl/add_column.go | 17 ++ pkg/expression/integration_test/BUILD.bazel | 2 +- .../integration_test/integration_test.go | 209 ++++++++++++++++++ pkg/table/column.go | 5 + 4 files changed, 232 insertions(+), 1 deletion(-) diff --git a/pkg/ddl/add_column.go b/pkg/ddl/add_column.go index 79e4170f0097e..dd64bb674e55d 100644 --- a/pkg/ddl/add_column.go +++ b/pkg/ddl/add_column.go @@ -766,6 +766,17 @@ func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr * col.DefaultIsExpr = true return str, false, nil + case ast.VecFromText: + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + default: return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String()) } @@ -1174,6 +1185,12 @@ func checkColumnValueConstraint(col *table.Column, collation string) error { // In NO_ZERO_DATE SQL mode, TIMESTAMP/DATE/DATETIME type can't have zero date like '0000-00-00' or '0000-00-00 00:00:00'. func checkColumnDefaultValue(ctx exprctx.BuildContext, col *table.Column, value any) (bool, any, error) { hasDefaultValue := true + + if value != nil && col.GetType() == mysql.TypeTiDBVectorFloat32 { + // In any SQL mode we don't allow VECTOR column to have a default value. + // Note that expression default is still supported. + return hasDefaultValue, value, errors.Errorf("VECTOR column '%-.192s' can't have a default value", col.Name.O) + } if value != nil && (col.GetType() == mysql.TypeJSON || col.GetType() == mysql.TypeTinyBlob || col.GetType() == mysql.TypeMediumBlob || col.GetType() == mysql.TypeLongBlob || col.GetType() == mysql.TypeBlob) { diff --git a/pkg/expression/integration_test/BUILD.bazel b/pkg/expression/integration_test/BUILD.bazel index b99f7815f8fd2..516b5048f6c9e 100644 --- a/pkg/expression/integration_test/BUILD.bazel +++ b/pkg/expression/integration_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 43, + shard_count = 44, deps = [ "//pkg/config", "//pkg/domain", diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index db9077abd246f..52428d5548904 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -60,6 +60,215 @@ import ( "github.com/tikv/client-go/v2/oracle" ) +func TestVectorDefaultValue(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + // ============================ + // NULLABLE, NO DEFAULT, Non-Strict Mode + + tk.MustExec("set @@session.sql_mode=''") + tk.MustExec("create table t(embedding VECTOR)") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExec("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExec("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + "[4] 1", + " ", + " ", + )) + tk.MustExec("drop table t") + + tk.MustExec("create table t(embedding VECTOR(3))") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExecToErr("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExec("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + " ", + " ", + )) + tk.MustExec("drop table t") + + // ============================ + // NULLABLE, NO DEFAULT, Strict Mode + + tk.MustExec("set @@session.sql_mode='STRICT_ALL_TABLES'") + tk.MustExec("create table t(embedding VECTOR)") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExec("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExec("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + "[4] 1", + " ", + " ", + )) + tk.MustExec("drop table t") + + tk.MustExec("create table t(embedding VECTOR(3))") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExecToErr("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExec("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + " ", + " ", + )) + tk.MustExec("drop table t") + + // ============================ + // NULLABLE, DEFAULT + + tk.MustGetErrMsg("create table t(embedding VECTOR DEFAULT '[1,2,3]')", `VECTOR column 'embedding' can't have a default value`) + tk.MustExec("create table t(embedding VECTOR DEFAULT (VEC_FROM_TEXT('[1,2,3]')))") + tk.MustExec("insert into t values ()") + tk.MustExec("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExec("insert into t values (NULL)") + tk.MustExec("insert into t values (DEFAULT(embedding))") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + "[4] 1", + "[1,2,3] 3", + " ", + "[1,2,3] 3", + )) + tk.MustExec("drop table t") + + // Allow. Error happens when inserting. + tk.MustExec("create table t(embedding VECTOR(5) DEFAULT (VEC_FROM_TEXT('[1,2,3]')))") + tk.MustGetErrMsg("insert into t values ()", "vector has 3 dimensions, does not fit VECTOR(5)") + tk.MustGetErrMsg("insert into t values (DEFAULT)", "vector has 3 dimensions, does not fit VECTOR(5)") + tk.MustExec("insert into t values ('[1,2,3,4,5]')") + tk.MustExec("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows("[1,2,3,4,5] 5", " ")) + tk.MustExec("drop table t") + + // Allow. Error happens when inserting. + tk.MustExec("create table t(embedding VECTOR(5) DEFAULT (UUID()))") + tk.MustContainErrMsg("insert into t values ()", "Invalid vector text: ") + tk.MustExec("insert into t values ('[1,2,3,4,5]')") + tk.MustExec("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows("[1,2,3,4,5] 5", " ")) + tk.MustExec("drop table t") + + // ============================ + // NOT NULL, NO DEFAULT, Non-Strict Mode + + tk.MustExec("set @@session.sql_mode=''") + tk.MustExec("create table t(embedding VECTOR NOT NULL)") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExec("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + "[4] 1", + "[] 0", + )) + tk.MustExec("drop table t") + + tk.MustExec("create table t(embedding VECTOR(3) NOT NULL)") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExecToErr("insert into t values ('[4]')") + tk.MustExecToErr("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + )) + tk.MustExec("drop table t") + + // ============================ + // NOT NULL, NO DEFAULT, Strict Mode + + tk.MustExec("set @@session.sql_mode='STRICT_ALL_TABLES'") + tk.MustExec("create table t(embedding VECTOR NOT NULL)") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExec("insert into t values ('[4]')") + tk.MustExecToErr("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + "[4] 1", + )) + tk.MustExec("drop table t") + + tk.MustExec("create table t(embedding VECTOR(3) NOT NULL)") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExecToErr("insert into t values ('[4]')") + tk.MustExecToErr("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + )) + tk.MustExec("drop table t") + + // ============================ + // NOT NULL, DEFAULT, Non-Strict Mode + + tk.MustExec("set @@session.sql_mode=''") + tk.MustExec("create table t(embedding VECTOR NOT NULL DEFAULT (VEC_FROM_TEXT('[1,2,3]')))") + tk.MustExec("insert into t values ()") + tk.MustExec("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustExec("insert into t values (DEFAULT(embedding))") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + "[4] 1", + "[1,2,3] 3", + "[1,2,3] 3", + )) + tk.MustExec("drop table t") + + tk.MustExec("create table t(embedding VECTOR(1) NOT NULL DEFAULT (VEC_FROM_TEXT('[1,2,3]')))") + tk.MustExecToErr("insert into t values ()") + tk.MustExec("insert into t values ('[4]')") + tk.MustExecToErr("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustExecToErr("insert into t values (DEFAULT(embedding))") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[4] 1", + )) + tk.MustExec("drop table t") + + // ============================ + // NOT NULL, DEFAULT, Strict Mode + + tk.MustExec("set @@session.sql_mode='STRICT_ALL_TABLES'") + tk.MustExec("create table t(embedding VECTOR NOT NULL DEFAULT (VEC_FROM_TEXT('[1,2,3]')))") + tk.MustExec("insert into t values ()") + tk.MustExec("insert into t values ('[4]')") + tk.MustExec("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustExec("insert into t values (DEFAULT(embedding))") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[1,2,3] 3", + "[4] 1", + "[1,2,3] 3", + "[1,2,3] 3", + )) + tk.MustExec("drop table t") + + tk.MustExec("create table t(embedding VECTOR(1) NOT NULL DEFAULT (VEC_FROM_TEXT('[1,2,3]')))") + tk.MustExecToErr("insert into t values ()") + tk.MustExec("insert into t values ('[4]')") + tk.MustExecToErr("insert into t values (DEFAULT)") + tk.MustExecToErr("insert into t values (NULL)") + tk.MustExecToErr("insert into t values (DEFAULT(embedding))") + tk.MustQuery("select *, vec_dims(embedding) from t").Check(testkit.Rows( + "[4] 1", + )) + tk.MustExec("drop table t") +} + func TestVectorColumnInfo(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/pkg/table/column.go b/pkg/table/column.go index cde8588ea524f..4b20df2cab7ee 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -493,6 +493,11 @@ func CheckOnce(cols []*Column) error { // Otherwise, it will return a ErrColumnCantNull when error. func (c *Column) CheckNotNull(data *types.Datum, rowCntInLoadData uint64) error { if (mysql.HasNotNullFlag(c.GetFlag()) || mysql.HasPreventNullInsertFlag(c.GetFlag())) && data.IsNull() { + if c.FieldType.EvalType().IsVectorKind() { + // Vector(N) is a special case because it does not have zero values as a fallback. + // So we always reject NULLs in NotNull context even if SQL mode is not strict. + return errors.Errorf("VECTOR column '%s' cannot be null", c.Name) + } if rowCntInLoadData > 0 { return ErrWarnNullToNotnull.GenWithStackByArgs(c.Name, rowCntInLoadData) }