diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp index 562b3756725..d0864ccb302 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp @@ -353,6 +353,7 @@ void DMFileWithVectorIndexBlockInputStream::loadVectorSearchResult() auto perf_begin = PerfContext::vector_search; + RUNTIME_CHECK(valid_rows.size() >= dmfile->getRows(), valid_rows.size(), dmfile->getRows()); auto results_rowid = vec_index->search(ann_query_info, valid_rows); auto discarded_nodes = PerfContext::vector_search.discarded_nodes - perf_begin.discarded_nodes; diff --git a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp index ffb559bd33d..25d4842e75a 100644 --- a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp +++ b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp @@ -39,6 +39,8 @@ MutableColumnPtr VectorColumnFromIndexReader::calcResultsByPack( // results must be in ascending order. std::sort(results.begin(), results.end()); + // results must not contain duplicates. Usually there should be no duplicates. + results.erase(std::unique(results.begin(), results.end()), results.end()); std::vector offsets_in_pack; size_t results_it = 0; diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp index 5a983824f18..b499fed5ec3 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp @@ -155,13 +155,24 @@ std::vector VectorIndexHNSWViewer::search( std::atomic visited_nodes = 0; std::atomic discarded_nodes = 0; + std::atomic has_exception_in_search = false; auto predicate = [&](typename USearchImplType::member_cref_t const & member) { - // Note: We don't increase the thread_local perf, to be compatible with future multi-thread change. - visited_nodes++; - if (!valid_rows[member.key]) - discarded_nodes++; - return valid_rows[member.key]; + // Must catch exceptions in the predicate, because search runs on other threads. + try + { + // Note: We don't increase the thread_local perf, because search runs on other threads. + visited_nodes++; + if (!valid_rows[member.key]) + discarded_nodes++; + return valid_rows[member.key]; + } + catch (...) + { + tryLogCurrentException(__PRETTY_FUNCTION__); + has_exception_in_search = true; + return false; + } }; // TODO: Support efSearch. @@ -169,6 +180,10 @@ std::vector VectorIndexHNSWViewer::search( reinterpret_cast(query_info->ref_vec_f32().data() + sizeof(UInt32)), query_info->top_k(), predicate); + + if (has_exception_in_search) + throw Exception(ErrorCodes::INCORRECT_QUERY, "Exception happened occurred during search"); + std::vector keys(result.size()); result.dump_to(keys.data()); diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp index 4ebfa77ed30..67c6c2dce42 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp @@ -498,6 +498,73 @@ try } CATCH +TEST_P(VectorIndexDMFileTest, OnePackWithDuplicateVectors) +try +{ + auto cols = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true); + auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + .kind = tipb::VectorIndexKind::HNSW, + .dimension = 3, + .distance_metric = tipb::VectorDistanceMetric::L2, + }); + cols->emplace_back(vec_cd); + + ColumnDefines read_cols = *cols; + if (test_only_vec_column) + read_cols = {vec_cd}; + + // Prepare DMFile + { + Block block = DMTestEnv::prepareSimpleWriteBlockWithNullable(0, 5); + block.insert(createVecFloat32Column( + {// + {1.0, 2.0, 3.0}, + {1.0, 2.0, 3.0}, + {0.0, 0.0, 0.0}, + {1.0, 2.0, 3.0}, + {1.0, 2.0, 3.5}}, + vec_cd.name, + vec_cd.id)); + auto stream = std::make_shared(dbContext(), dm_file, *cols); + stream->writePrefix(); + stream->write(block, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0}); + stream->writeSuffix(); + } + + dm_file = restoreDMFile(); + + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_cd.id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(4); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.5})); + + DMFileBlockInputStreamBuilder builder(dbContext()); + auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) + .setBitmapFilter(BitmapFilterView(std::make_shared(5, true), 0, 5)) + .build2( + dm_file, + read_cols, + RowKeyRanges{RowKeyRange::newAll(false, 1)}, + std::make_shared()); + + ASSERT_INPUTSTREAM_COLS_UR( + stream, + createColumnNames(), + createColumnData({ + createColumn({0, 1, 3, 4}), + createVecFloat32Column({// + {1.0, 2.0, 3.0}, + {1.0, 2.0, 3.0}, + {1.0, 2.0, 3.0}, + {1.0, 2.0, 3.5}}), + })); + } +} +CATCH + TEST_P(VectorIndexDMFileTest, MultiPacks) try {