Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ddl: Support parsing VectorIndex defined in IndexInfo #9448

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 59 additions & 36 deletions dbms/src/TiDB/Schema/TiDB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,47 @@ using DB::Exception;
using DB::Field;
using DB::SchemaNameMapper;

VectorIndexDefinitionPtr parseVectorIndexFromJSON(const Poco::JSON::Object::Ptr & json)
{
assert(json); // not nullptr

tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND;
auto kind_field = json->getValue<String>("kind");
RUNTIME_CHECK_MSG(tipb::VectorIndexKind_Parse(kind_field, &kind), "invalid kind of vector index, {}", kind_field);
RUNTIME_CHECK(kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);

auto dimension = json->getValue<UInt64>("dimension");
RUNTIME_CHECK(dimension > 0);
RUNTIME_CHECK(dimension <= TiDB::MAX_VECTOR_DIMENSION); // Just a protection

tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC;
auto distance_metric_field = json->getValue<String>("distance_metric");
RUNTIME_CHECK_MSG(
tipb::VectorDistanceMetric_Parse(distance_metric_field, &distance_metric),
"invalid distance_metric of vector index, {}",
distance_metric_field);
RUNTIME_CHECK(distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

return std::make_shared<const VectorIndexDefinition>(VectorIndexDefinition{
.kind = kind,
.dimension = dimension,
.distance_metric = distance_metric,
});
}

Poco::JSON::Object::Ptr vectorIndexToJSON(const VectorIndexDefinitionPtr & vector_index)
{
assert(vector_index != nullptr);
RUNTIME_CHECK(vector_index->kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);
RUNTIME_CHECK(vector_index->distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object();
vector_index_json->set("kind", tipb::VectorIndexKind_Name(vector_index->kind));
vector_index_json->set("dimension", vector_index->dimension);
vector_index_json->set("distance_metric", tipb::VectorDistanceMetric_Name(vector_index->distance_metric));
return vector_index_json;
}

////////////////////////
////// ColumnInfo //////
////////////////////////
Expand Down Expand Up @@ -413,15 +454,7 @@ try

if (vector_index)
{
RUNTIME_CHECK(vector_index->kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);
RUNTIME_CHECK(vector_index->distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object();
vector_index_json->set("kind", tipb::VectorIndexKind_Name(vector_index->kind));
vector_index_json->set("dimension", vector_index->dimension);
vector_index_json->set("distance_metric", tipb::VectorDistanceMetric_Name(vector_index->distance_metric));

json->set("vector_index", vector_index_json);
json->set("vector_index", vectorIndexToJSON(vector_index));
}

#ifndef NDEBUG
Expand Down Expand Up @@ -476,34 +509,9 @@ try
}
state = static_cast<SchemaState>(json->getValue<Int32>("state"));

auto vector_index_json = json->getObject("vector_index");
if (vector_index_json)
if (auto vector_index_json = json->getObject("vector_index"); vector_index_json)
{
tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND;
auto kind_field = vector_index_json->getValue<String>("kind");
auto ok = tipb::VectorIndexKind_Parse( //
kind_field,
&kind);
RUNTIME_CHECK_MSG(ok, "invalid kind of vector index, {}", kind_field);
RUNTIME_CHECK(kind != tipb::VectorIndexKind::INVALID_INDEX_KIND);

auto dimension = vector_index_json->getValue<UInt64>("dimension");
RUNTIME_CHECK(dimension > 0);
RUNTIME_CHECK(dimension <= TiDB::MAX_VECTOR_DIMENSION); // Just a protection

tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC;
auto distance_metric_field = vector_index_json->getValue<String>("distance_metric");
ok = tipb::VectorDistanceMetric_Parse( //
distance_metric_field,
&distance_metric);
RUNTIME_CHECK_MSG(ok, "invalid distance_metric of vector index, {}", distance_metric_field);
RUNTIME_CHECK(distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC);

vector_index = std::make_shared<const VectorIndexDefinition>(VectorIndexDefinition{
.kind = kind,
.dimension = dimension,
.distance_metric = distance_metric,
});
vector_index = parseVectorIndexFromJSON(vector_index_json);
}
}
catch (const Poco::Exception & e)
Expand Down Expand Up @@ -846,6 +854,11 @@ try
json->set("is_invisible", is_invisible);
json->set("is_global", is_global);

if (vector_index)
{
json->set("vector_index", vectorIndexToJSON(vector_index));
}

#ifndef NDEBUG
std::stringstream str;
json->stringify(str);
Expand Down Expand Up @@ -886,6 +899,11 @@ try
is_invisible = json->getValue<bool>("is_invisible");
if (json->has("is_global"))
is_global = json->getValue<bool>("is_global");

if (auto vector_index_json = json->getObject("vector_index"); vector_index_json)
{
vector_index = parseVectorIndexFromJSON(vector_index_json);
}
}
catch (const Poco::Exception & e)
{
Expand Down Expand Up @@ -1024,6 +1042,10 @@ try
// always put the primary_index at the front of all index_info
index_infos.insert(index_infos.begin(), std::move(index_info));
}
else if (index_info.vector_index != nullptr)
{
index_infos.emplace_back(std::move(index_info));
}
}
}

Expand Down Expand Up @@ -1180,6 +1202,7 @@ const IndexInfo & TableInfo::getPrimaryIndexInfo() const
#endif
return index_infos[0];
}

size_t TableInfo::numColumnsInKey() const
{
if (pk_is_handle)
Expand Down
3 changes: 3 additions & 0 deletions dbms/src/TiDB/Schema/TiDB.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ struct IndexInfo
bool is_primary = false;
bool is_invisible = false;
bool is_global = false;

VectorIndexDefinitionPtr vector_index = nullptr;
};

struct TableInfo
Expand Down Expand Up @@ -331,6 +333,7 @@ struct TableInfo

/// should not be called if is_common_handle = false.
const IndexInfo & getPrimaryIndexInfo() const;

size_t numColumnsInKey() const;
};

Expand Down
92 changes: 92 additions & 0 deletions dbms/src/TiDB/Schema/tests/gtest_table_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,98 @@ try
}
CATCH

TEST(TiDBTableInfoTest, ParseVectorIndexJSON)
try
{
auto cases = {
ParseCase{
R"json({"cols":[{"default":null,"default_bit":null,"id":1,"name":{"L":"col1","O":"col1"},"offset":-1,"origin_default":null,"state":0,"type":{"Charset":null,"Collate":null,"Decimal":0,"Elems":null,"Flag":4097,"Flen":0,"Tp":8}},{"default":null,"default_bit":null,"id":2,"name":{"L":"vec","O":"vec"},"offset":-1,"origin_default":null,"state":0,"type":{"Charset":null,"Collate":null,"Decimal":0,"Elems":null,"Flag":4097,"Flen":0,"Tp":225}}],"id":30,"index_info":[{"id":3,"idx_cols":[{"length":-1,"name":{"L":"vec","O":"vec"},"offset":0}],"idx_name":{"L":"idx1","O":"idx1"},"index_type":-1,"is_global":false,"is_invisible":false,"is_primary":false,"is_unique":false,"state":5,"vector_index":{"dimension":3,"distance_metric":"L2","kind":"HNSW"}}],"is_common_handle":false,"name":{"L":"t1","O":"t1"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":0,"update_timestamp":1723778704444603})json",
[](const TableInfo & table_info) {
ASSERT_EQ(table_info.index_infos.size(), 1);
auto idx = table_info.index_infos[0];
ASSERT_EQ(idx.id, 3);
ASSERT_EQ(idx.idx_cols.size(), 1);
ASSERT_EQ(idx.idx_cols[0].name, "vec");
ASSERT_EQ(idx.idx_cols[0].offset, 0);
ASSERT_EQ(idx.idx_cols[0].length, -1);
ASSERT_NE(idx.vector_index, nullptr);
ASSERT_EQ(idx.vector_index->kind, tipb::VectorIndexKind::HNSW);
ASSERT_EQ(idx.vector_index->dimension, 3);
ASSERT_EQ(idx.vector_index->distance_metric, tipb::VectorDistanceMetric::L2);
ASSERT_EQ(table_info.columns.size(), 2);
auto col0 = table_info.columns[0];
ASSERT_EQ(col0.name, "col1");
ASSERT_EQ(col0.tp, TiDB::TP::TypeLongLong);
ASSERT_EQ(col0.id, 1);
auto col1 = table_info.columns[1];
ASSERT_EQ(col1.name, "vec");
ASSERT_EQ(col1.tp, TiDB::TP::TypeTiDBVectorFloat32);
ASSERT_EQ(col1.id, 2);
},
},
// VectorIndex defined in the ColumnInfo
ParseCase{
R"json({"cols":[{"comment":"hnsw(distance=l2)","default":null,"default_bit":null,"id":1,"name":{"L":"v","O":"v"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":128,"Flen":5,"Tp":225},"vector_index":{"dimension":5,"distance_metric":"L2","kind":"HNSW"}}],"comment":"","id":96,"index_info":[],"is_common_handle":false,"keyspace_id":1,"name":{"L":"t","O":"t"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":true,"Count":1},"update_timestamp":451956855279976452})json",
[](const TableInfo & table_info) {
ASSERT_EQ(table_info.index_infos.size(), 0);
ASSERT_EQ(table_info.columns.size(), 1);
auto col = table_info.columns[0];
ASSERT_EQ(col.name, "v");
ASSERT_EQ(col.tp, TiDB::TP::TypeTiDBVectorFloat32);
ASSERT_EQ(col.id, 1);
auto vector_index_on_col = col.vector_index;
ASSERT_NE(vector_index_on_col, nullptr);
ASSERT_EQ(vector_index_on_col->kind, tipb::VectorIndexKind::HNSW);
ASSERT_EQ(vector_index_on_col->dimension, 5);
ASSERT_EQ(vector_index_on_col->distance_metric, tipb::VectorDistanceMetric::L2);
},
},
ParseCase{
R"json({"cols":[{"comment":"","default":null,"default_bit":null,"id":1,"name":{"L":"col","O":"col"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":4099,"Flen":20,"Tp":8}},{"comment":"","default":null,"default_bit":null,"id":2,"name":{"L":"v","O":"v"},"offset":1,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":128,"Flen":5,"Tp":225}}],"comment":"","id":96,"index_info":[{"id":4,"idx_cols":[{"length":-1,"name":{"L":"v","O":"v"},"offset":1}],"idx_name":{"L":"idx_v_l2","O":"idx_v_l2"},"index_type":5,"is_global":false,"is_invisible":false,"is_primary":false,"is_unique":false,"state":3,"vector_index":{"dimension":5,"distance_metric":"L2","kind":"HNSW"}},{"id":3,"idx_cols":[{"length":-1,"name":{"L":"col","O":"col"},"offset":0}],"idx_name":{"L":"primary","O":"primary"},"index_type":1,"is_global":false,"is_invisible":false,"is_primary":true,"is_unique":true,"state":5}],"is_common_handle":false,"keyspace_id":1,"name":{"L":"ti","O":"ti"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":true,"Count":1},"update_timestamp":452024291984670725})json",
[](const TableInfo & table_info) {
// vector index && primary index
// primary index alwasy be put at the first
ASSERT_EQ(table_info.index_infos.size(), 2);
auto idx0 = table_info.index_infos[0];
ASSERT_TRUE(idx0.is_primary);
ASSERT_TRUE(idx0.is_unique);
ASSERT_EQ(idx0.id, 3);
ASSERT_EQ(idx0.idx_name, "primary");
ASSERT_EQ(idx0.idx_cols.size(), 1);
ASSERT_EQ(idx0.idx_cols[0].name, "col");
ASSERT_EQ(idx0.idx_cols[0].offset, 0);
ASSERT_EQ(idx0.vector_index, nullptr);
// vec index
auto idx1 = table_info.index_infos[1];
ASSERT_EQ(idx1.id, 4);
ASSERT_EQ(idx1.idx_name, "idx_v_l2");
ASSERT_EQ(idx1.idx_cols.size(), 1);
ASSERT_EQ(idx1.idx_cols[0].name, "v");
ASSERT_EQ(idx1.idx_cols[0].offset, 1);
ASSERT_NE(idx1.vector_index, nullptr);
ASSERT_EQ(idx1.vector_index->kind, tipb::VectorIndexKind::HNSW);
ASSERT_EQ(idx1.vector_index->dimension, 5);
ASSERT_EQ(idx1.vector_index->distance_metric, tipb::VectorDistanceMetric::L2);

ASSERT_EQ(table_info.columns.size(), 2);
auto col0 = table_info.columns[0];
ASSERT_EQ(col0.name, "col");
ASSERT_EQ(col0.tp, TiDB::TP::TypeLongLong);
ASSERT_EQ(col0.id, 1);
auto col1 = table_info.columns[1];
ASSERT_EQ(col1.name, "v");
ASSERT_EQ(col1.tp, TiDB::TP::TypeTiDBVectorFloat32);
ASSERT_EQ(col1.id, 2);
}}};

for (const auto & c : cases)
{
TableInfo table_info(c.table_info_json, NullspaceID);
c.check(table_info);
}
}
CATCH

struct StmtCase
{
TableID table_or_partition_id;
Expand Down