Skip to content

Commit

Permalink
compute: Use SimSIMD for vectors (pingcap#221)
Browse files Browse the repository at this point in the history
Signed-off-by: Wish <[email protected]>
  • Loading branch information
breezewish authored and Lloyd-Pottiger committed Aug 26, 2024
1 parent 3f15689 commit feda156
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 47 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,6 @@
[submodule "contrib/usearch"]
path = contrib/usearch
url = https://github.com/unum-cloud/usearch.git
[submodule "contrib/simsimd"]
path = contrib/simsimd
url = https://github.com/ashvardanian/SimSIMD
2 changes: 2 additions & 0 deletions contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,5 @@ add_subdirectory(simdjson)
add_subdirectory(fastpforlib)

add_subdirectory(usearch-cmake)

add_subdirectory(simsimd-cmake)
1 change: 1 addition & 0 deletions contrib/simsimd
Submodule simsimd added at 3e2193
13 changes: 13 additions & 0 deletions contrib/simsimd-cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
set(SIMSIMD_PROJECT_DIR "${TiFlash_SOURCE_DIR}/contrib/simsimd")
set(SIMSIMD_SOURCE_DIR "${SIMSIMD_PROJECT_DIR}/include")

add_library(_simsimd INTERFACE)

if (NOT EXISTS "${SIMSIMD_SOURCE_DIR}/simsimd/simsimd.h")
message (FATAL_ERROR "submodule contrib/simsimd not found")
endif()

target_include_directories(_simsimd SYSTEM INTERFACE
${SIMSIMD_SOURCE_DIR})

add_library(tiflash_contrib::simsimd ALIAS _simsimd)
2 changes: 1 addition & 1 deletion contrib/usearch-cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ if (NOT EXISTS "${USEARCH_SOURCE_DIR}/usearch/index.hpp")
endif ()

target_include_directories(_usearch SYSTEM INTERFACE
${USEARCH_PROJECT_DIR}/simsimd/include
# ${USEARCH_PROJECT_DIR}/simsimd/include # Use our simsimd
${USEARCH_PROJECT_DIR}/fp16/include
${USEARCH_SOURCE_DIR})

Expand Down
1 change: 1 addition & 0 deletions dbms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ target_link_libraries (dbms
${BTRIE_LIBRARIES}
absl::synchronization
tiflash_contrib::usearch
tiflash_contrib::simsimd
tiflash_contrib::aws_s3

etcdpb
Expand Down
10 changes: 9 additions & 1 deletion dbms/src/Functions/tests/gtest_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,15 @@ TEST_F(Vector, CosineDistance)
try
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<Float64>>({0.0, std::nullopt, 0.0, 1.0, 2.0, 0.0, 2.0, std::nullopt}),
createColumn<Nullable<Float64>>(
{0.004130363464355469,
1.0, // CosDistance to (0,0) cannot be calculated, clapped to 1.0
0.00572967529296875,
1.0,
1.9942703247070312,
0.00022123707458376884,
1.9997787475585938,
std::nullopt}),
executeFunction(
"vecCosineDistance",
createColumn<Array>(
Expand Down
9 changes: 4 additions & 5 deletions dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,16 +192,15 @@ std::vector<VectorIndexBuilder::Key> VectorIndexHNSWViewer::search(
std::atomic<size_t> discarded_nodes = 0;
std::atomic<bool> has_exception_in_search = false;

// The non-valid rows should be discarded by this lambda
auto predicate = [&](typename USearchImplType::member_cref_t const & member) {
auto predicate = [&](const Key & key) {
// Must catch exceptions in the predicate, because search runs on other threads.
try
{
// Note: We don't increase the thread_local perf, because search runs on other threads.
visited_nodes++;
if (!valid_rows[member.key])
if (!valid_rows[key])
discarded_nodes++;
return valid_rows[member.key];
return valid_rows[key];
}
catch (...)
{
Expand All @@ -215,7 +214,7 @@ std::vector<VectorIndexBuilder::Key> VectorIndexHNSWViewer::search(
SCOPE_EXIT({ GET_METRIC(tiflash_vector_index_duration, type_search).Observe(w.elapsedSeconds()); });

// TODO(vector-index): Support efSearch.
auto result = index.search( //
auto result = index.filtered_search( //
reinterpret_cast<const Float32 *>(query_info->ref_vec_f32().data() + sizeof(UInt32)),
query_info->top_k(),
predicate);
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

#include <Storages/DeltaMerge/File/dtpb/dmfile.pb.h>
#include <Storages/DeltaMerge/Index/VectorIndex.h>
#include <Storages/DeltaMerge/Index/VectorIndexHNSW/usearch_index_dense.h>

#include <usearch/index_dense.hpp>

namespace DB::DM
{
Expand Down
50 changes: 11 additions & 39 deletions dbms/src/TiDB/Decode/Vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#include <IO/WriteHelpers.h>
#include <TiDB/Decode/Vector.h>

#define SIMSIMD_NATIVE_F16 0
#define SIMSIMD_NATIVE_BF16 0
#include <simsimd/simsimd.h>

#include <compare>

namespace DB
Expand Down Expand Up @@ -50,15 +54,8 @@ Float64 VectorFloat32Ref::l2SquaredDistance(VectorFloat32Ref b) const
{
checkDims(b);

Float32 distance = 0.0;
Float32 diff;

for (size_t i = 0, i_max = size(); i < i_max; ++i)
{
// Hope this can be vectorized.
diff = elements[i] - b[i];
distance += diff * diff;
}
simsimd_distance_t distance;
simsimd_l2sq_f32(elements, b.elements, elements_n, &distance);

return distance;
}
Expand All @@ -67,13 +64,8 @@ Float64 VectorFloat32Ref::innerProduct(VectorFloat32Ref b) const
{
checkDims(b);

Float32 distance = 0.0;

for (size_t i = 0, i_max = size(); i < i_max; ++i)
{
// Hope this can be vectorized.
distance += elements[i] * b[i];
}
simsimd_distance_t distance;
simsimd_dot_f32(elements, b.elements, elements_n, &distance);

return distance;
}
Expand All @@ -82,30 +74,10 @@ Float64 VectorFloat32Ref::cosineDistance(VectorFloat32Ref b) const
{
checkDims(b);

Float32 distance = 0.0;
Float32 norma = 0.0;
Float32 normb = 0.0;

for (size_t i = 0, i_max = size(); i < i_max; ++i)
{
// Hope this can be vectorized.
distance += elements[i] * b[i];
norma += elements[i] * elements[i];
normb += b[i] * b[i];
}

Float64 similarity
= static_cast<Float64>(distance) / std::sqrt(static_cast<Float64>(norma) * static_cast<Float64>(normb));

if (std::isnan(similarity))
{
// When norma or normb is zero, distance is zero, and similarity is NaN.
// similarity can not be Inf in this case.
return std::nan("");
}
simsimd_distance_t distance;
simsimd_cos_f32(elements, b.elements, elements_n, &distance);

similarity = std::clamp(similarity, -1.0, 1.0);
return 1.0 - similarity;
return distance;
}

Float64 VectorFloat32Ref::l1Distance(VectorFloat32Ref b) const
Expand Down

0 comments on commit feda156

Please sign in to comment.