From db17e31f5f878cd14a526e24bf96fff5c3e1ea51 Mon Sep 17 00:00:00 2001 From: Lloyd-Pottiger <60744015+Lloyd-Pottiger@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:58:09 +0800 Subject: [PATCH] ddl: Support parsing VectorIndex defined in IndexInfo (#9448) ref pingcap/tiflash#9032 ddl: Support parsing VectorIndex defined in IndexInfo Co-authored-by: JaySon --- dbms/src/TiDB/Schema/TiDB.cpp | 95 ++++++++++++------- dbms/src/TiDB/Schema/TiDB.h | 3 + .../TiDB/Schema/tests/gtest_table_info.cpp | 92 ++++++++++++++++++ 3 files changed, 154 insertions(+), 36 deletions(-) diff --git a/dbms/src/TiDB/Schema/TiDB.cpp b/dbms/src/TiDB/Schema/TiDB.cpp index d74be73795a..471c944b2f0 100644 --- a/dbms/src/TiDB/Schema/TiDB.cpp +++ b/dbms/src/TiDB/Schema/TiDB.cpp @@ -104,6 +104,47 @@ using DB::Exception; using DB::Field; using DB::SchemaNameMapper; +VectorIndexDefinitionPtr parseVectorIndexFromJSON(const Poco::JSON::Object::Ptr & json) +{ + assert(json); // not nullptr + + tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND; + auto kind_field = json->getValue("kind"); + RUNTIME_CHECK_MSG(tipb::VectorIndexKind_Parse(kind_field, &kind), "invalid kind of vector index, {}", kind_field); + RUNTIME_CHECK(kind != tipb::VectorIndexKind::INVALID_INDEX_KIND); + + auto dimension = json->getValue("dimension"); + RUNTIME_CHECK(dimension > 0); + RUNTIME_CHECK(dimension <= TiDB::MAX_VECTOR_DIMENSION); // Just a protection + + tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC; + auto distance_metric_field = json->getValue("distance_metric"); + RUNTIME_CHECK_MSG( + tipb::VectorDistanceMetric_Parse(distance_metric_field, &distance_metric), + "invalid distance_metric of vector index, {}", + distance_metric_field); + RUNTIME_CHECK(distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); + + return std::make_shared(VectorIndexDefinition{ + .kind = kind, + .dimension = dimension, + .distance_metric = distance_metric, + }); +} + +Poco::JSON::Object::Ptr vectorIndexToJSON(const VectorIndexDefinitionPtr & vector_index) +{ + assert(vector_index != nullptr); + RUNTIME_CHECK(vector_index->kind != tipb::VectorIndexKind::INVALID_INDEX_KIND); + RUNTIME_CHECK(vector_index->distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); + + Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object(); + vector_index_json->set("kind", tipb::VectorIndexKind_Name(vector_index->kind)); + vector_index_json->set("dimension", vector_index->dimension); + vector_index_json->set("distance_metric", tipb::VectorDistanceMetric_Name(vector_index->distance_metric)); + return vector_index_json; +} + //////////////////////// ////// ColumnInfo ////// //////////////////////// @@ -413,15 +454,7 @@ try if (vector_index) { - RUNTIME_CHECK(vector_index->kind != tipb::VectorIndexKind::INVALID_INDEX_KIND); - RUNTIME_CHECK(vector_index->distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); - - Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object(); - vector_index_json->set("kind", tipb::VectorIndexKind_Name(vector_index->kind)); - vector_index_json->set("dimension", vector_index->dimension); - vector_index_json->set("distance_metric", tipb::VectorDistanceMetric_Name(vector_index->distance_metric)); - - json->set("vector_index", vector_index_json); + json->set("vector_index", vectorIndexToJSON(vector_index)); } #ifndef NDEBUG @@ -476,34 +509,9 @@ try } state = static_cast(json->getValue("state")); - auto vector_index_json = json->getObject("vector_index"); - if (vector_index_json) + if (auto vector_index_json = json->getObject("vector_index"); vector_index_json) { - tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND; - auto kind_field = vector_index_json->getValue("kind"); - auto ok = tipb::VectorIndexKind_Parse( // - kind_field, - &kind); - RUNTIME_CHECK_MSG(ok, "invalid kind of vector index, {}", kind_field); - RUNTIME_CHECK(kind != tipb::VectorIndexKind::INVALID_INDEX_KIND); - - auto dimension = vector_index_json->getValue("dimension"); - RUNTIME_CHECK(dimension > 0); - RUNTIME_CHECK(dimension <= TiDB::MAX_VECTOR_DIMENSION); // Just a protection - - tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC; - auto distance_metric_field = vector_index_json->getValue("distance_metric"); - ok = tipb::VectorDistanceMetric_Parse( // - distance_metric_field, - &distance_metric); - RUNTIME_CHECK_MSG(ok, "invalid distance_metric of vector index, {}", distance_metric_field); - RUNTIME_CHECK(distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); - - vector_index = std::make_shared(VectorIndexDefinition{ - .kind = kind, - .dimension = dimension, - .distance_metric = distance_metric, - }); + vector_index = parseVectorIndexFromJSON(vector_index_json); } } catch (const Poco::Exception & e) @@ -846,6 +854,11 @@ try json->set("is_invisible", is_invisible); json->set("is_global", is_global); + if (vector_index) + { + json->set("vector_index", vectorIndexToJSON(vector_index)); + } + #ifndef NDEBUG std::stringstream str; json->stringify(str); @@ -886,6 +899,11 @@ try is_invisible = json->getValue("is_invisible"); if (json->has("is_global")) is_global = json->getValue("is_global"); + + if (auto vector_index_json = json->getObject("vector_index"); vector_index_json) + { + vector_index = parseVectorIndexFromJSON(vector_index_json); + } } catch (const Poco::Exception & e) { @@ -1024,6 +1042,10 @@ try // always put the primary_index at the front of all index_info index_infos.insert(index_infos.begin(), std::move(index_info)); } + else if (index_info.vector_index != nullptr) + { + index_infos.emplace_back(std::move(index_info)); + } } } @@ -1180,6 +1202,7 @@ const IndexInfo & TableInfo::getPrimaryIndexInfo() const #endif return index_infos[0]; } + size_t TableInfo::numColumnsInKey() const { if (pk_is_handle) diff --git a/dbms/src/TiDB/Schema/TiDB.h b/dbms/src/TiDB/Schema/TiDB.h index 7ef29e437bc..cc40dfed1a6 100644 --- a/dbms/src/TiDB/Schema/TiDB.h +++ b/dbms/src/TiDB/Schema/TiDB.h @@ -261,6 +261,8 @@ struct IndexInfo bool is_primary = false; bool is_invisible = false; bool is_global = false; + + VectorIndexDefinitionPtr vector_index = nullptr; }; struct TableInfo @@ -331,6 +333,7 @@ struct TableInfo /// should not be called if is_common_handle = false. const IndexInfo & getPrimaryIndexInfo() const; + size_t numColumnsInKey() const; }; diff --git a/dbms/src/TiDB/Schema/tests/gtest_table_info.cpp b/dbms/src/TiDB/Schema/tests/gtest_table_info.cpp index 5d622c3e240..f5233e5d7b6 100644 --- a/dbms/src/TiDB/Schema/tests/gtest_table_info.cpp +++ b/dbms/src/TiDB/Schema/tests/gtest_table_info.cpp @@ -131,6 +131,98 @@ try } CATCH +TEST(TiDBTableInfoTest, ParseVectorIndexJSON) +try +{ + auto cases = { + ParseCase{ + R"json({"cols":[{"default":null,"default_bit":null,"id":1,"name":{"L":"col1","O":"col1"},"offset":-1,"origin_default":null,"state":0,"type":{"Charset":null,"Collate":null,"Decimal":0,"Elems":null,"Flag":4097,"Flen":0,"Tp":8}},{"default":null,"default_bit":null,"id":2,"name":{"L":"vec","O":"vec"},"offset":-1,"origin_default":null,"state":0,"type":{"Charset":null,"Collate":null,"Decimal":0,"Elems":null,"Flag":4097,"Flen":0,"Tp":225}}],"id":30,"index_info":[{"id":3,"idx_cols":[{"length":-1,"name":{"L":"vec","O":"vec"},"offset":0}],"idx_name":{"L":"idx1","O":"idx1"},"index_type":-1,"is_global":false,"is_invisible":false,"is_primary":false,"is_unique":false,"state":5,"vector_index":{"dimension":3,"distance_metric":"L2","kind":"HNSW"}}],"is_common_handle":false,"name":{"L":"t1","O":"t1"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":0,"update_timestamp":1723778704444603})json", + [](const TableInfo & table_info) { + ASSERT_EQ(table_info.index_infos.size(), 1); + auto idx = table_info.index_infos[0]; + ASSERT_EQ(idx.id, 3); + ASSERT_EQ(idx.idx_cols.size(), 1); + ASSERT_EQ(idx.idx_cols[0].name, "vec"); + ASSERT_EQ(idx.idx_cols[0].offset, 0); + ASSERT_EQ(idx.idx_cols[0].length, -1); + ASSERT_NE(idx.vector_index, nullptr); + ASSERT_EQ(idx.vector_index->kind, tipb::VectorIndexKind::HNSW); + ASSERT_EQ(idx.vector_index->dimension, 3); + ASSERT_EQ(idx.vector_index->distance_metric, tipb::VectorDistanceMetric::L2); + ASSERT_EQ(table_info.columns.size(), 2); + auto col0 = table_info.columns[0]; + ASSERT_EQ(col0.name, "col1"); + ASSERT_EQ(col0.tp, TiDB::TP::TypeLongLong); + ASSERT_EQ(col0.id, 1); + auto col1 = table_info.columns[1]; + ASSERT_EQ(col1.name, "vec"); + ASSERT_EQ(col1.tp, TiDB::TP::TypeTiDBVectorFloat32); + ASSERT_EQ(col1.id, 2); + }, + }, + // VectorIndex defined in the ColumnInfo + ParseCase{ + R"json({"cols":[{"comment":"hnsw(distance=l2)","default":null,"default_bit":null,"id":1,"name":{"L":"v","O":"v"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":128,"Flen":5,"Tp":225},"vector_index":{"dimension":5,"distance_metric":"L2","kind":"HNSW"}}],"comment":"","id":96,"index_info":[],"is_common_handle":false,"keyspace_id":1,"name":{"L":"t","O":"t"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":true,"Count":1},"update_timestamp":451956855279976452})json", + [](const TableInfo & table_info) { + ASSERT_EQ(table_info.index_infos.size(), 0); + ASSERT_EQ(table_info.columns.size(), 1); + auto col = table_info.columns[0]; + ASSERT_EQ(col.name, "v"); + ASSERT_EQ(col.tp, TiDB::TP::TypeTiDBVectorFloat32); + ASSERT_EQ(col.id, 1); + auto vector_index_on_col = col.vector_index; + ASSERT_NE(vector_index_on_col, nullptr); + ASSERT_EQ(vector_index_on_col->kind, tipb::VectorIndexKind::HNSW); + ASSERT_EQ(vector_index_on_col->dimension, 5); + ASSERT_EQ(vector_index_on_col->distance_metric, tipb::VectorDistanceMetric::L2); + }, + }, + ParseCase{ + R"json({"cols":[{"comment":"","default":null,"default_bit":null,"id":1,"name":{"L":"col","O":"col"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":4099,"Flen":20,"Tp":8}},{"comment":"","default":null,"default_bit":null,"id":2,"name":{"L":"v","O":"v"},"offset":1,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":128,"Flen":5,"Tp":225}}],"comment":"","id":96,"index_info":[{"id":4,"idx_cols":[{"length":-1,"name":{"L":"v","O":"v"},"offset":1}],"idx_name":{"L":"idx_v_l2","O":"idx_v_l2"},"index_type":5,"is_global":false,"is_invisible":false,"is_primary":false,"is_unique":false,"state":3,"vector_index":{"dimension":5,"distance_metric":"L2","kind":"HNSW"}},{"id":3,"idx_cols":[{"length":-1,"name":{"L":"col","O":"col"},"offset":0}],"idx_name":{"L":"primary","O":"primary"},"index_type":1,"is_global":false,"is_invisible":false,"is_primary":true,"is_unique":true,"state":5}],"is_common_handle":false,"keyspace_id":1,"name":{"L":"ti","O":"ti"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":true,"Count":1},"update_timestamp":452024291984670725})json", + [](const TableInfo & table_info) { + // vector index && primary index + // primary index alwasy be put at the first + ASSERT_EQ(table_info.index_infos.size(), 2); + auto idx0 = table_info.index_infos[0]; + ASSERT_TRUE(idx0.is_primary); + ASSERT_TRUE(idx0.is_unique); + ASSERT_EQ(idx0.id, 3); + ASSERT_EQ(idx0.idx_name, "primary"); + ASSERT_EQ(idx0.idx_cols.size(), 1); + ASSERT_EQ(idx0.idx_cols[0].name, "col"); + ASSERT_EQ(idx0.idx_cols[0].offset, 0); + ASSERT_EQ(idx0.vector_index, nullptr); + // vec index + auto idx1 = table_info.index_infos[1]; + ASSERT_EQ(idx1.id, 4); + ASSERT_EQ(idx1.idx_name, "idx_v_l2"); + ASSERT_EQ(idx1.idx_cols.size(), 1); + ASSERT_EQ(idx1.idx_cols[0].name, "v"); + ASSERT_EQ(idx1.idx_cols[0].offset, 1); + ASSERT_NE(idx1.vector_index, nullptr); + ASSERT_EQ(idx1.vector_index->kind, tipb::VectorIndexKind::HNSW); + ASSERT_EQ(idx1.vector_index->dimension, 5); + ASSERT_EQ(idx1.vector_index->distance_metric, tipb::VectorDistanceMetric::L2); + + ASSERT_EQ(table_info.columns.size(), 2); + auto col0 = table_info.columns[0]; + ASSERT_EQ(col0.name, "col"); + ASSERT_EQ(col0.tp, TiDB::TP::TypeLongLong); + ASSERT_EQ(col0.id, 1); + auto col1 = table_info.columns[1]; + ASSERT_EQ(col1.name, "v"); + ASSERT_EQ(col1.tp, TiDB::TP::TypeTiDBVectorFloat32); + ASSERT_EQ(col1.id, 2); + }}}; + + for (const auto & c : cases) + { + TableInfo table_info(c.table_info_json, NullspaceID); + c.check(table_info); + } +} +CATCH + struct StmtCase { TableID table_or_partition_id;