Skip to content

Commit

Permalink
fix: Support urls
Browse files Browse the repository at this point in the history
  • Loading branch information
Josephasafg committed Jun 13, 2024
1 parent 55aa272 commit 3fc7265
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 2 deletions.
10 changes: 9 additions & 1 deletion ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
env_config: _AI21EnvConfig = AI21EnvConfig,
**kwargs,
):
base_url = api_host or f"{env_config.api_host}/studio/v1"
base_url = self._create_url(api_host or env_config.api_host)

self._http_client = AI21HTTPClient(
api_key=api_key or env_config.api_key,
Expand All @@ -70,6 +70,14 @@ def __init__(
self.segmentation = StudioSegmentation(self._http_client)
self.beta = Beta(self._http_client)

def _create_url(self, base_url: str) -> str:
allowed_urls = ["https://api-stage.ai21.com", "https://api.ai21.com"]

if base_url in allowed_urls:
return f"{base_url}/studio/v1"

return base_url

def count_tokens(self, text: str, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> int:
warnings.warn(
"Please use the global get_tokenizer() method directly instead of the AI21Client().count_tokens() method.",
Expand Down
10 changes: 9 additions & 1 deletion ai21/clients/studio/async_ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
env_config: _AI21EnvConfig = AI21EnvConfig,
**kwargs,
):
base_url = api_host or f"{env_config.api_host}/studio/v1"
base_url = self._create_url(api_host or env_config.api_host)

self._http_client = AsyncAI21HTTPClient(
api_key=api_key or env_config.api_key,
Expand All @@ -63,3 +63,11 @@ def __init__(
self.library = AsyncStudioLibrary(self._http_client)
self.segmentation = AsyncStudioSegmentation(self._http_client)
self.beta = AsyncBeta(self._http_client)

def _create_url(self, base_url: str) -> str:
allowed_urls = ["https://api-stage.ai21.com", "https://api.ai21.com"]

if base_url in allowed_urls:
return f"{base_url}/studio/v1"

return base_url
1 change: 1 addition & 0 deletions ai21/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
DEFAULT_API_VERSION = "v1"
STUDIO_HOST = "https://api.ai21.com"
"https://api-stage.ai21.com"
14 changes: 14 additions & 0 deletions tests/unittests/clients/studio/test_ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,17 @@ def test_async_ai21_client__when_pass_api_host__should_leave_as_is():
def test_async_ai21_client__when_not_pass_api_host__should_add_suffix():
client = AsyncAI21Client()
assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1"


@pytest.mark.asyncio
def test_async_ai21_client__when_pass_ai21_api_host__should_add_suffix():
ai21_url = "https://api.ai21.com"
client = AsyncAI21Client(api_host=ai21_url)
assert client._http_client._base_url == f"{ai21_url}/studio/v1"


@pytest.mark.asyncio
def test_async_ai21_client__when_pass_ai21_with_suffix__should_not_modify():
ai21_url = "https://api.ai21.com/studio/v1"
client = AsyncAI21Client(api_host=ai21_url)
assert client._http_client._base_url == ai21_url
12 changes: 12 additions & 0 deletions tests/unittests/clients/studio/test_async_ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,15 @@ def test_ai21_client__when_pass_api_host__should_leave_as_is():
def test_ai21_client__when_not_pass_api_host__should_add_suffix():
client = AI21Client()
assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1"


def test_ai21_client__when_pass_ai21_api_host__should_add_suffix():
ai21_url = "https://api.ai21.com"
client = AI21Client(api_host=ai21_url)
assert client._http_client._base_url == f"{ai21_url}/studio/v1"


def test_ai21_client__when_pass_ai21_with_suffix__should_not_modify():
ai21_url = "https://api.ai21.com/studio/v1"
client = AI21Client(api_host=ai21_url)
assert client._http_client._base_url == ai21_url

0 comments on commit 3fc7265

Please sign in to comment.