Skip to content

Commit

Permalink
feat: Check version of Elasticsearch server and add support for Elast…
Browse files Browse the repository at this point in the history
…icsearch <= 7.5 (#5320)

* Check ES server version + add support for ES <= 7.5

* Adapt comment

* PR feedback
  • Loading branch information
bogdankostic authored Jul 13, 2023
1 parent 63fd63f commit 237d67d
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 2 deletions.
21 changes: 20 additions & 1 deletion haystack/document_stores/elasticsearch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ def _validate_and_adjust_document_index(self, index_name: str, headers: Optional
mapping["properties"][self.embedding_field] = {"type": "dense_vector", "dims": self.embedding_dim}
self._index_put_mapping(index=index_id, body=mapping, headers=headers)

def _validate_server_version(self, expected_version: int):
"""
Validate that the Elasticsearch server version is compatible with the used ElasticsearchDocumentStore.
"""
if self.server_version[0] != expected_version:
logger.warning(
"This ElasticsearchDocumentStore has been built for Elasticsearch %s, but the detected version of the "
"Elasticsearch server is %s. Unexpected behaviors or errors may occur due to version incompatibility.",
expected_version,
".".join(map(str, self.server_version)),
)

def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int):
"""
Generate Elasticsearch query for vector similarity.
Expand All @@ -302,12 +314,19 @@ def _get_vector_similarity_query(self, query_emb: np.ndarray, top_k: int):
if self.skip_missing_embeddings:
script_score_query = {"bool": {"filter": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}}

# Elasticsearch 7.6 introduced a breaking change regarding the vector function signatures:
# https://www.elastic.co/guide/en/elasticsearch/reference/7.6/breaking-changes-7.6.html#_update_to_vector_function_signatures
if self.server_version[0] == 7 and self.server_version[1] < 6:
similarity_script_source = f"{similarity_fn_name}(params.query_vector,doc['{self.embedding_field}']) + 1000"
else:
similarity_script_source = f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000"

query = {
"script_score": {
"query": script_score_query,
"script": {
# offset score to ensure a positive range as required by Elasticsearch
"source": f"{similarity_fn_name}(params.query_vector,'{self.embedding_field}') + 1000",
"source": similarity_script_source,
"params": {"query_vector": query_emb.tolist()},
},
}
Expand Down
2 changes: 2 additions & 0 deletions haystack/document_stores/elasticsearch/es7.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def __init__(
batch_size=batch_size,
)

self._validate_server_version(expected_version=7)

def _do_bulk(self, *args, **kwargs):
"""Override the base class method to use the Elasticsearch client"""
return bulk(*args, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions haystack/document_stores/elasticsearch/es8.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def __init__(
batch_size=batch_size,
)

self._validate_server_version(expected_version=8)

def _do_bulk(self, *args, **kwargs):
"""Override the base class method to use the Elasticsearch client"""
return bulk(*args, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions haystack/document_stores/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(
raise DocumentStoreError(
f"Invalid value {similarity} for similarity, choose between 'cosine', 'l2' and 'dot_product'"
)
client_info = self.client.info()
self.server_version = tuple(int(num) for num in client_info["version"]["number"].split("."))

self._init_indices(
index=index, label_index=label_index, create_index=create_index, recreate_index=recreate_index
Expand Down
33 changes: 32 additions & 1 deletion test/document_stores/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def mocked_document_store(self):
ElasticsearchDocumentStore equipped with a mocked client
"""

with patch(f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"):
with patch(
f"{ElasticsearchDocumentStore.__module__}.ElasticsearchDocumentStore._init_elastic_client"
) as mocked_init_client:
if VERSION[0] == 7:
mocked_init_client().info.return_value = {"version": {"number": "7.17.6"}}
else:
mocked_init_client().info.return_value = {"version": {"number": "8.8.0"}}

class DSMock(ElasticsearchDocumentStore):
# We mock a subclass to avoid messing up the actual class object
Expand Down Expand Up @@ -376,6 +382,31 @@ def test_write_documents_req_for_each_batch(self, mocked_document_store, documen
mocked_document_store.write_documents(documents)
assert mocked_bulk.call_count == 5

@pytest.mark.unit
def test_get_vector_similarity_query(self, mocked_document_store):
"""
Test that the source field of the vector similarity query is correctly formatted for ES 7.6 and above.
We test this to make sure we use the correct syntax for newer ES versions.
"""
vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
assert vec_sim_query["script_score"]["script"]["source"] == "dotProduct(params.query_vector,'embedding') + 1000"

@pytest.mark.unit
def test_get_vector_similarity_query_es_7_5_and_below(self, mocked_document_store):
"""
Test that the source field of the vector similarity query is correctly formatter for ES 7.5 and below.
We test this to make sure we use the correct syntax for ES versions older than 7.6, as the syntax changed
in 7.6.
"""
# Patch server version to be 7.5.0
mocked_document_store.server_version = (7, 5, 0)

vec_sim_query = mocked_document_store._get_vector_similarity_query(np.random.rand(3).astype(np.float32), 10)
assert (
vec_sim_query["script_score"]["script"]["source"]
== "dotProduct(params.query_vector,doc['embedding']) + 1000"
)

# The following tests are overridden only to be able to skip them depending on ES version

@pytest.mark.skipif(VERSION[0] == 8, reason="Elasticsearch 8 uses a different client call")
Expand Down
1 change: 1 addition & 0 deletions test/document_stores/test_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class DSMock(OpenSearchDocumentStore):
opensearch_mock = MagicMock()
opensearch_mock.indices.exists.return_value = True
opensearch_mock.indices.get.return_value = {self.index_name: existing_index}
opensearch_mock.info.return_value = {"version": {"number": "1.3.5"}}
DSMock._init_client = MagicMock()
DSMock._init_client.configure_mock(return_value=opensearch_mock)
dsMock = DSMock()
Expand Down

0 comments on commit 237d67d

Please sign in to comment.