-
Notifications
You must be signed in to change notification settings - Fork 201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: VoyageAI encoder #255
base: main
Are you sure you want to change the base?
Changes from all commits
2e4a13d
1229f07
bc5b975
759c56c
7a5470c
c2f74d8
4832c27
2f05ea8
c4fe976
10b9e54
d5d997d
966b23a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
from time import sleep | ||
from typing import Any, List, Optional | ||
|
||
from pydantic.v1 import PrivateAttr | ||
|
||
from semantic_router.encoders import BaseEncoder | ||
from semantic_router.utils.defaults import EncoderDefault | ||
from semantic_router.utils.logger import logger | ||
|
||
|
||
class VoyageAIEncoder(BaseEncoder): | ||
_client: Any = PrivateAttr() | ||
type: str = "voyageai" | ||
|
||
def __init__( | ||
self, | ||
name: Optional[str] = None, | ||
voyage_api_key: Optional[str] = None, | ||
score_threshold: float = 0.82, | ||
): | ||
if name is None: | ||
name = EncoderDefault.VOYAGE.value["embedding_model"] | ||
super().__init__(name=name, score_threshold=score_threshold) | ||
self._client = self._initialize_client(api_key=voyage_api_key) | ||
|
||
def _initialize_client(self, api_key: Optional[str] = None): | ||
try: | ||
import voyageai | ||
except ImportError: | ||
raise ImportError( | ||
"Please install VoyageAI to use VoyageAIEncoder. " | ||
"You can install it with: " | ||
"`pip install 'semantic-router[voyageai]'`" | ||
) | ||
|
||
api_key = api_key or os.getenv("VOYAGEAI_API_KEY") | ||
if api_key is None: | ||
raise ValueError("VoyageAI API key not provided") | ||
try: | ||
client = voyageai.Client(api_key=api_key) | ||
except Exception as e: | ||
raise ValueError(f"Unable to connect to VoyageAI {e.args}: {e}") from e | ||
return client | ||
|
||
def __call__(self, docs: List[str]) -> List[List[float]]: | ||
if self._client == PrivateAttr(): | ||
raise ValueError("VoyageAI client is not initialized.") | ||
embeds = None | ||
error_message = "" | ||
|
||
# Exponential backoff | ||
for j in range(1, 7): | ||
try: | ||
embeds = self._client.embed( | ||
texts=docs, | ||
model=self.name, | ||
input_type="query", # query or document | ||
) | ||
if embeds.embeddings: | ||
break | ||
else: | ||
sleep(2**j) | ||
logger.warning(f"Retrying in {2**j} seconds...") | ||
|
||
except Exception as e: | ||
logger.error(f"VoyageAI API call failed. Error: {error_message}") | ||
raise ValueError(f"VoyageAI API call failed. Error: {e}") from e | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The last test problems seem to be due to the handling of exceptions in the exponential backoff, currently it will stop on the first exception and not retry at all due to the exception being raised straight away. To get it to continue for the specified number of retries as the test requires you could just add a line to continue the retries before the raise which should pass those other tests except Exception as e:
logger.error(f"VoyageAI API call failed. Error: {error_message}")
if j < 6:
sleep(2 ** j)
continue
raise ValueError(f"VoyageAI API call failed. Error: {e}") from e |
||
|
||
if not embeds or not embeds.embeddings: | ||
raise ValueError("VoyageAI API call failed. Error: No embeddings found.") | ||
|
||
return embeds.embeddings |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from pydantic.v1 import PrivateAttr | ||
|
||
from semantic_router.encoders import VoyageAIEncoder | ||
|
||
|
||
@pytest.fixture | ||
def voyageai_encoder(mocker): | ||
mocker.patch("voyageai.Client") | ||
return VoyageAIEncoder(voyage_api_key="test_api_key") | ||
|
||
|
||
class TestVoyageAIEncoder: | ||
def test_voyageai_encoder_import_error(self): | ||
with patch.dict("sys.modules", {"voyageai": None}): | ||
with pytest.raises(ImportError) as error: | ||
VoyageAIEncoder() | ||
|
||
assert "pip install 'semantic-router[voyageai]'" in str(error.value) | ||
|
||
def test_voyageai_encoder_init_success(self, mocker): | ||
side_effect = ["fake-model-name", "fake-api-key"] | ||
mocker.patch("os.getenv", side_effect=side_effect) | ||
encoder = VoyageAIEncoder() | ||
assert encoder._client is not PrivateAttr() | ||
|
||
def test_voyageai_encoder_init_no_api_key(self, mocker): | ||
mocker.patch("os.getenv", return_value=None) | ||
with pytest.raises(ValueError) as _: | ||
VoyageAIEncoder() | ||
|
||
def test_voyageai_encoder_call_uninitialized_client(self, voyageai_encoder): | ||
voyageai_encoder._client = PrivateAttr() | ||
with pytest.raises(ValueError) as e: | ||
voyageai_encoder(["test document"]) | ||
assert "VoyageAI client is not initialized." in str(e.value) | ||
|
||
def test_voyageai_encoder_init_exception(self, mocker): | ||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch("voyageai.Client", side_effect=Exception("Initialization error")) | ||
with pytest.raises(ValueError) as e: | ||
VoyageAIEncoder() | ||
assert ( | ||
"VOYAGE API client failed to initialize. Error: Initialization error" | ||
in str(e.value) | ||
) | ||
|
||
def test_voyageai_encoder_call_success(self, voyageai_encoder, mocker): | ||
mock_response = mocker.Mock() | ||
mock_response.embeddings = [[0.1, 0.2]] | ||
|
||
mocker.patch("os.getenv", return_value="fake-api-key", autospec=True) | ||
mocker.patch("time.sleep", return_value=None) | ||
|
||
mocker.patch.object( | ||
voyageai_encoder._client, "embed", return_value=mock_response | ||
) | ||
embeddings = voyageai_encoder(["test document"]) | ||
assert embeddings == [[0.1, 0.2]] | ||
|
||
def test_voyageai_encoder_call_with_retries(self, voyageai_encoder, mocker): | ||
error = Exception("Network error") | ||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch("time.sleep", return_value=None) | ||
mocker.patch.object( | ||
voyageai_encoder._client, | ||
"embed", | ||
side_effect=[error, error, mocker.Mock(embeddings=[[0.1, 0.2]])], | ||
) | ||
embeddings = voyageai_encoder(["test document"]) | ||
assert embeddings == [[0.1, 0.2]] | ||
|
||
def test_voyageai_encoder_call_failure_non_voyage_error( | ||
self, voyageai_encoder, mocker | ||
): | ||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch("time.sleep", return_value=None) | ||
mocker.patch.object( | ||
voyageai_encoder._client, | ||
"embed", | ||
side_effect=Exception("General error"), | ||
) | ||
with pytest.raises(ValueError) as e: | ||
voyageai_encoder(["test document"]) | ||
assert "VoyageAI API call failed. Error: General error" in str(e.value) | ||
|
||
def test_voyageai_encoder_call_successful_retry(self, voyageai_encoder, mocker): | ||
mock_response = mocker.Mock() | ||
mock_response.embeddings = [[0.1, 0.2]] | ||
|
||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch("time.sleep", return_value=None) | ||
|
||
responses = [Exception("Temporary error"), mock_response] | ||
mocker.patch.object(voyageai_encoder._client, "embed", side_effect=responses) | ||
embeddings = voyageai_encoder(["test document"]) | ||
assert embeddings == [[0.1, 0.2]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to help with the test problems, the test is looking to assert the error message outputs as
"VOYAGE API client failed to initialize. Error: Initialization error"
but because e.args returns a tuple the output from the failure is actually"Unable to connect to VoyageAI ('Initialization error',): Initialization error"
.To pass the assertion you could update your test assertion to match the args output or you could update the actual error message to
ValueError(f"VOYAGE API client failed to initialize. Error: {e}")
which should pass