Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Make sure Client parameters are strings #1577

Merged
merged 5 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def EphemeralClient(
settings = Settings()
settings.is_persistent = False

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)

return ClientCreator(settings=settings, tenant=tenant, database=database)


Expand All @@ -135,12 +139,16 @@ def PersistentClient(
settings.persist_directory = path
settings.is_persistent = True

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)

return ClientCreator(tenant=tenant, database=database, settings=settings)


def HttpClient(
host: str = "localhost",
port: str = "8000",
port: int = 8000,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what will happen if the user currently is passing a string? just a type error, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users can pass whatever they want -- these are just type hints which users can choose to statically check or not. If a user is currently passing a string we'll turn it into an int with the port = int(port) line below. This will not break anyone.

ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
settings: Optional[Settings] = None,
Expand All @@ -165,6 +173,13 @@ def HttpClient(
if settings is None:
settings = Settings()

# Make sure paramaters are the correct types -- users can pass anything.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the int conversions have a try/catch in case of failure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're fine erroring if someone passes in a port that's not an int (same applies to other types). All we would do is throw another error.

host = str(host)
port = int(port)
ssl = bool(ssl)
tenant = str(tenant)
database = str(database)

settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
if settings.chroma_server_host and settings.chroma_server_host != host:
raise ValueError(
Expand All @@ -189,7 +204,7 @@ def CloudClient(
settings: Optional[Settings] = None,
*, # Following arguments are keyword-only, intended for testing only.
cloud_host: str = "api.trychroma.com",
cloud_port: str = "8000",
cloud_port: int = 8000,
enable_ssl: bool = True,
) -> ClientAPI:
"""
Expand Down Expand Up @@ -217,6 +232,14 @@ def CloudClient(
if settings is None:
settings = Settings()

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)
api_key = str(api_key)
cloud_host = str(cloud_host)
cloud_port = int(cloud_port)
enable_ssl = bool(enable_ssl)

settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
settings.chroma_server_host = cloud_host
settings.chroma_server_http_port = cloud_port
Expand All @@ -242,9 +265,12 @@ def Client(

tenant: The tenant to use for this client. Defaults to the default tenant.
database: The database to use for this client. Defaults to the default database.

"""

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)

return ClientCreator(tenant=tenant, database=database, settings=settings)


Expand Down
2 changes: 1 addition & 1 deletion chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(self, system: System):

self._api_url = FastAPI.resolve_url(
chroma_server_host=str(system.settings.chroma_server_host),
chroma_server_http_port=int(str(system.settings.chroma_server_http_port)),
chroma_server_http_port=system.settings.chroma_server_http_port,
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
default_api_path=system.settings.chroma_server_api_default_path,
)
Expand Down
8 changes: 4 additions & 4 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ class Settings(BaseSettings): # type: ignore

chroma_server_host: Optional[str] = None
chroma_server_headers: Optional[Dict[str, str]] = None
chroma_server_http_port: Optional[str] = None
chroma_server_http_port: Optional[int] = None
chroma_server_ssl_enabled: Optional[bool] = False
chroma_server_api_default_path: Optional[str] = "/api/v1"
chroma_server_grpc_port: Optional[str] = None
chroma_server_grpc_port: Optional[int] = None
# eg ["http://localhost:3000"]
chroma_server_cors_allow_origins: List[str] = []

Expand All @@ -134,8 +134,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
chroma_server_nofile: Optional[int] = None

pulsar_broker_url: Optional[str] = None
pulsar_admin_port: Optional[str] = "8080"
pulsar_broker_port: Optional[str] = "6650"
pulsar_admin_port: Optional[int] = 8080
pulsar_broker_port: Optional[int] = 6650

chroma_server_auth_provider: Optional[str] = None

Expand Down
6 changes: 3 additions & 3 deletions chromadb/test/client/test_cloud_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host=TEST_CLOUD_HOST,
chroma_server_http_port=str(port),
chroma_server_http_port=port,
chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
chroma_client_auth_credentials=valid_token,
chroma_client_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
Expand All @@ -82,7 +82,7 @@ def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
database=DEFAULT_DATABASE,
api_key=valid_token,
cloud_host=TEST_CLOUD_HOST,
cloud_port=mock_cloud_server.settings.chroma_server_http_port, # type: ignore
cloud_port=mock_cloud_server.settings.chroma_server_http_port or 8000,
enable_ssl=False,
)

Expand All @@ -98,7 +98,7 @@ def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
database=DEFAULT_DATABASE,
api_key=invalid_token,
cloud_host=TEST_CLOUD_HOST,
cloud_port=mock_cloud_server.settings.chroma_server_http_port, # type: ignore
cloud_port=mock_cloud_server.settings.chroma_server_http_port or 8000,
enable_ssl=False,
)
client.heartbeat()
4 changes: 2 additions & 2 deletions chromadb/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _fastapi_fixture(
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host="localhost",
chroma_server_http_port=str(port),
chroma_server_http_port=port,
allow_reset=True,
chroma_client_auth_provider=chroma_client_auth_provider,
chroma_client_auth_credentials=chroma_client_auth_credentials,
Expand Down Expand Up @@ -234,7 +234,7 @@ def fastapi_persistent() -> Generator[System, None, None]:
def basic_http_client() -> Generator[System, None, None]:
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_http_port="8000",
chroma_server_http_port=8000,
allow_reset=True,
)
system = System(settings)
Expand Down
6 changes: 3 additions & 3 deletions chromadb/test/test_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_fastapi(self, mock: Mock) -> None:
chroma_api_impl="chromadb.api.fastapi.FastAPI",
persist_directory="./foo",
chroma_server_host="foo",
chroma_server_http_port="80",
chroma_server_http_port=80,
)
)
assert mock.called
Expand All @@ -78,7 +78,7 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None:
settings = chromadb.config.Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host="foo",
chroma_server_http_port="80",
chroma_server_http_port=80,
chroma_server_headers={"foo": "bar"},
)
client = chromadb.Client(settings)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_legacy_values() -> None:
chroma_api_impl="chromadb.api.local.LocalAPI",
persist_directory="./foo",
chroma_server_host="foo",
chroma_server_http_port="80",
chroma_server_http_port=80,
)
)
client.clear_system_cache()
4 changes: 2 additions & 2 deletions chromadb/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def test_http_client_with_inconsistent_host_settings() -> None:
def test_http_client_with_inconsistent_port_settings() -> None:
try:
chromadb.HttpClient(
port="8002",
port=8002,
settings=Settings(
chroma_server_http_port="8001",
chroma_server_http_port=8001,
),
)
except ValueError as e:
Expand Down