Skip to content

Commit

Permalink
[ENH] Make Rust/C++ FFI error handling robust (#2667)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Introduces a mechanism to propagate errors from C++ into rust by
catching errors in the bindings and then populating a thread_local
variable.
	 - Build the c++ code automatically on change
	 - Propagate new errors up through the codebase
	 - Coalesce shared functionality in rust hnsw tests
- The C++ bindings were constructing errors but not throwing them. This
path is purely defensive, so it was never exercised; this change fixes that.
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
Added a new test to test add() errors
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB committed Aug 15, 2024
1 parent c6aa0c8 commit fb1201a
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 122 deletions.
202 changes: 174 additions & 28 deletions rust/index/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Assumes that chroma-hnswlib is checked out at the same level as chroma
#include "../../../hnswlib/hnswlib/hnswlib.h"
#include <thread>

class AllowAndDisallowListFilterFunctor : public hnswlib::BaseFilterFunctor
{
Expand All @@ -23,6 +24,12 @@ class AllowAndDisallowListFilterFunctor : public hnswlib::BaseFilterFunctor
}
};

// thread-local for the last error message, callers are expected to check this
// the empty string represents no error
// this is currently shared across all instances of Index, but that's fine for now
// since it is thread-local
thread_local std::string last_error;

template <typename dist_t, typename data_t = float>
class Index
{
Expand Down Expand Up @@ -57,6 +64,7 @@ class Index
}
appr_alg = NULL;
index_inited = false;
last_error.clear();
}

~Index()
Expand All @@ -72,7 +80,7 @@ class Index
{
if (index_inited)
{
std::runtime_error("Index already inited");
throw std::runtime_error("Index already inited");
}
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, max_elements, M, ef_construction, random_seed, allow_replace_deleted, normalize, is_persistent_index, persistence_location);
appr_alg->ef_ = 10; // This is a default value for ef_
Expand All @@ -83,7 +91,7 @@ class Index
{
if (index_inited)
{
std::runtime_error("Index already inited");
throw std::runtime_error("Index already inited");
}
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, path_to_index, false, 0, allow_replace_deleted, normalize, is_persistent_index);
index_inited = true;
Expand All @@ -93,7 +101,7 @@ class Index
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}
appr_alg->persistDirty();
}
Expand All @@ -102,16 +110,17 @@ class Index
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}

appr_alg->addPoint(data, id);
}

void get_item(const hnswlib::labeltype id, data_t *data)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}
std::vector<data_t> ret_data = appr_alg->template getDataByLabel<data_t>(id); // This checks if id is deleted
for (int i = 0; i < dim; i++)
Expand All @@ -120,21 +129,20 @@ class Index
}
}

int mark_deleted(const hnswlib::labeltype id)
void mark_deleted(const hnswlib::labeltype id)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}
appr_alg->markDelete(id);
return 0;
}

size_t knn_query(const data_t *query_vector, const size_t k, hnswlib::labeltype *ids, data_t *distance, const hnswlib::labeltype *allowed_ids, const size_t allowed_id_length, const hnswlib::labeltype *disallowed_ids, const size_t disallowed_id_length)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}

std::unordered_set<hnswlib::labeltype> allow_list;
Expand All @@ -155,11 +163,6 @@ class Index
}
AllowAndDisallowListFilterFunctor filter = AllowAndDisallowListFilterFunctor(allow_list, disallow_list);
std::priority_queue<std::pair<dist_t, hnswlib::labeltype>> res = appr_alg->searchKnn(query_vector, k, &filter);
if (res.size() < k)
{
// TODO: This is ok and we should return < K results, but for maintining compatibility with the old API we throw an error for now
std::runtime_error("Not enough results");
}
int total_results = std::min(res.size(), k);
for (int i = total_results - 1; i >= 0; i--)
{
Expand All @@ -175,7 +178,7 @@ class Index
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}
return appr_alg->ef_;
}
Expand All @@ -184,7 +187,7 @@ class Index
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}
appr_alg->ef_ = ef;
}
Expand All @@ -193,76 +196,219 @@ class Index
{
if (!index_inited)
{
std::runtime_error("Index not inited");
throw std::runtime_error("Index not inited");
}
appr_alg->resizeIndex(new_size);
}
};

// All these methods except for len() and capacity() can "throw" a std::exception
// and populate the last_error thread-local variable. This is how we communicate
// errors across the FFI boundary - the C++ layer will catch all exceptions and
// set the last_error variable, which the Rust layer can then check.
// Comments referring to "throwing" exceptions in this block refer to this mechanism.
extern "C"
{

// Allocates a new Index<float> for the given space name and dimensionality.
//
// "Can throw" in the FFI sense: any std::exception raised during construction
// is caught here, its message is stored in the thread-local last_error, and
// nullptr is returned. On success last_error is cleared and the new index is
// returned.
Index<float> *create_index(const char *space_name, const int dim)
{
    Index<float> *index = nullptr;
    try
    {
        index = new Index<float>(space_name, dim);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return nullptr;
    }
    last_error.clear();
    // Return the object built inside the try block. Allocating a second one
    // here would leak the first and could throw outside the guarded region.
    return index;
}

// Initialises `index` by forwarding every argument to Index::init_index.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and the function returns.
// On success last_error is cleared.
void init_index(Index<float> *index, const size_t max_elements, const size_t M, const size_t ef_construction, const size_t random_seed, const bool allow_replace_deleted, const bool is_persistent_index, const char *persistence_location)
{
    // The call must happen exactly once and only inside the try block: a second
    // call would fail with "Index already inited", and an unguarded call could
    // let an exception escape through the C ABI.
    try
    {
        index->init_index(max_elements, M, ef_construction, random_seed, allow_replace_deleted, is_persistent_index, persistence_location);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return;
    }
    last_error.clear();
}

// Loads an existing index into `index` by forwarding to Index::load_index.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and the function returns.
// On success last_error is cleared.
void load_index(Index<float> *index, const char *path_to_index, const bool allow_replace_deleted, const bool is_persistent_index)
{
    // Single guarded call: loading twice would trip the "Index already inited"
    // check, and an unguarded call could let an exception cross the C ABI.
    try
    {
        index->load_index(path_to_index, allow_replace_deleted, is_persistent_index);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return;
    }
    last_error.clear();
}

// Forwards to Index::persist_dirty for `index`.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and the function returns.
// On success last_error is cleared.
void persist_dirty(Index<float> *index)
{
    // Only call inside the try block so no exception can escape through the
    // C ABI, and so the persist work is not performed twice.
    try
    {
        index->persist_dirty();
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return;
    }
    last_error.clear();
}

// Adds the vector `data` under label `id` by forwarding to Index::add_item,
// passing `replace_deleted` through unchanged.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and the function returns.
// On success last_error is cleared.
void add_item(Index<float> *index, const float *data, const hnswlib::labeltype id, const bool replace_deleted)
{
    // Exactly one guarded call: an extra unguarded call would insert the point
    // a second time and could let an exception cross the C ABI.
    try
    {
        index->add_item(data, id, replace_deleted);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return;
    }
    last_error.clear();
}

// Copies the stored vector for label `id` into the caller-provided `data`
// buffer by forwarding to Index::get_item.
// NOTE(review): `data` presumably must hold at least `dim` floats — confirm
// against Index::get_item.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and the function returns.
// On success last_error is cleared.
void get_item(Index<float> *index, const hnswlib::labeltype id, float *data)
{
    // Single guarded call so that a lookup failure is reported via last_error
    // instead of propagating through the C ABI.
    try
    {
        index->get_item(id, data);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return;
    }
    last_error.clear();
}

int mark_deleted(Index<float> *index, const hnswlib::labeltype id)
// Can throw std::exception
void mark_deleted(Index<float> *index, const hnswlib::labeltype id)
{
return index->mark_deleted(id);
try
{
index->mark_deleted(id);
}
catch (std::exception &e)
{
last_error = e.what();
return;
}
last_error.clear();
}

// Runs a k-nearest-neighbour query by forwarding every argument to
// Index::knn_query and returning its result.
// NOTE(review): the result is presumably the number of entries written to
// `ids`/`distance` — confirm against Index::knn_query.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and 0 is returned. Because
// 0 may also be a legitimate result, callers must consult get_last_error to
// tell the two apart. On success last_error is cleared.
size_t knn_query(Index<float> *index, const float *query_vector, const size_t k, hnswlib::labeltype *ids, float *distance, const hnswlib::labeltype *allowed_ids, const size_t allowed_id_length, const hnswlib::labeltype *disallowed_ids, const size_t disallowed_id_length)
{
    size_t result = 0;
    // The query must run inside the try block; returning the call's result
    // directly would bypass the catch and leave last_error stale.
    try
    {
        result = index->knn_query(query_vector, k, ids, distance, allowed_ids, allowed_id_length, disallowed_ids, disallowed_id_length);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return 0;
    }
    last_error.clear();
    return result;
}

// Returns the current `ef` value of `index` via Index::get_ef.
//
// "Can throw" in the FFI sense: a std::exception from the call (for example
// the "Index not inited" check) is caught, its message is stored in the
// thread-local last_error, and -1 is returned. On success last_error is
// cleared.
int get_ef(Index<float> *index)
{
    int ret = -1;
    // Go through Index::get_ef rather than reading appr_alg->ef_ directly:
    // appr_alg is NULL until the index is initialised, and get_ef performs
    // the inited check.
    try
    {
        ret = index->get_ef();
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return -1;
    }
    last_error.clear();
    return ret;
}

// Sets the `ef` value of `index` by forwarding to Index::set_ef.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and the function returns.
// On success last_error is cleared.
void set_ef(Index<float> *index, const size_t ef)
{
    // Single guarded call so the "Index not inited" error is reported through
    // last_error instead of escaping through the C ABI.
    try
    {
        index->set_ef(ef);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return;
    }
    last_error.clear();
}

// Reports the number of live elements in `index`: the current element count
// minus the deleted count, as reported by the underlying hnswlib object.
// Returns 0 when the index has not been initialised. Does not touch
// last_error and has no try/catch of its own.
int len(Index<float> *index)
{
    if (index->index_inited)
    {
        auto *alg = index->appr_alg;
        return alg->getCurrentElementCount() - alg->getDeletedCount();
    }
    // No underlying hnswlib object exists yet.
    return 0;
}

// Reports the `max_elements_` value of the underlying hnswlib object, or 0
// when `index` has not been initialised. Does not touch last_error and has no
// try/catch of its own.
size_t capacity(Index<float> *index)
{
    return index->index_inited ? index->appr_alg->max_elements_ : 0;
}

// Resizes `index` to `new_size` by forwarding to Index::resize_index.
//
// "Can throw" in the FFI sense: a std::exception from the call is caught, its
// message is stored in the thread-local last_error, and the function returns.
// On success last_error is cleared.
void resize_index(Index<float> *index, size_t new_size)
{
    // Single guarded call: resizing outside the try block could let an
    // exception escape through the C ABI and would repeat the resize work.
    try
    {
        index->resize_index(new_size);
    }
    catch (const std::exception &e)
    {
        last_error = e.what();
        return;
    }
    last_error.clear();
}

// Returns the calling thread's last error message, or nullptr when the
// thread-local last_error is empty (meaning no error is recorded).
// `index` is not used: last_error is shared by all Index instances on a thread.
// The returned pointer refers to last_error's internal buffer, so it is only
// valid until last_error is next modified on this thread.
const char *get_last_error(Index<float> *index)
{
    return last_error.empty() ? nullptr : last_error.c_str();
}
}
2 changes: 2 additions & 0 deletions rust/index/build.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Tell cargo to rerun this build script if the bindings change.
println!("cargo:rerun-if-changed=bindings.cpp");
// Compile the hnswlib bindings.
cc::Build::new()
.cpp(true)
Expand Down
Loading

0 comments on commit fb1201a

Please sign in to comment.