Skip to content

Commit

Permalink
Default None, type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffchuber authored and atroyn committed Jun 28, 2024
1 parent 8fb314d commit 55951e7
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 14 deletions.
4 changes: 2 additions & 2 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def list_collections(
def create_collection(
self,
name: str,
configuration: Optional[CollectionConfiguration],
configuration: Optional[CollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
Expand All @@ -575,7 +575,7 @@ def get_collection(
def get_or_create_collection(
self,
name: str,
configuration: Optional[CollectionConfiguration],
configuration: Optional[CollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand Down
8 changes: 4 additions & 4 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ async def _get(
page: Optional[int] = None,
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = {},
include: Include = ["embeddings", "metadatas", "documents"],
include: Include = ["embeddings", "metadatas", "documents"], # type: ignore[list-item]
) -> GetResult:
"""[Internal] Returns entries from a collection specified by UUID.
Expand Down Expand Up @@ -264,7 +264,7 @@ async def _query(
n_results: int = 10,
where: Where = {},
where_document: WhereDocument = {},
include: Include = ["embeddings", "metadatas", "documents", "distances"],
include: Include = ["embeddings", "metadatas", "documents", "distances"], # type: ignore[list-item]
) -> QueryResult:
"""[Internal] Performs a nearest neighbors query on a collection specified by UUID.
Expand Down Expand Up @@ -544,7 +544,7 @@ async def count_collections(
async def create_collection(
self,
name: str,
configuration: Optional[CollectionConfiguration],
configuration: Optional[CollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
Expand All @@ -566,7 +566,7 @@ async def get_collection(
async def get_or_create_collection(
self,
name: str,
configuration: Optional[CollectionConfiguration],
configuration: Optional[CollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand Down
10 changes: 5 additions & 5 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_tenant(self, name: str) -> t.Tenant:
def create_collection(
self,
name: str,
configuration: Optional[CollectionConfiguration],
configuration: Optional[CollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
Expand Down Expand Up @@ -213,12 +213,12 @@ def create_collection(
def get_or_create_collection(
self,
name: str,
configuration: Optional[CollectionConfiguration],
configuration: Optional[CollectionConfiguration] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
return self.create_collection( # type: ignore
return self.create_collection(
name=name,
metadata=metadata,
configuration=configuration,
Expand Down Expand Up @@ -459,7 +459,7 @@ def _get(
page: Optional[int] = None,
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = {},
include: Include = ["embeddings", "metadatas", "documents"],
include: Include = ["embeddings", "metadatas", "documents"], # type: ignore[list-item]
) -> GetResult:
add_attributes_to_current_span(
{
Expand Down Expand Up @@ -638,7 +638,7 @@ def _query(
n_results: int = 10,
where: Where = {},
where_document: WhereDocument = {},
include: Include = ["documents", "metadatas", "distances"],
include: Include = ["documents", "metadatas", "distances"], # type: ignore[list-item]
) -> QueryResult:
add_attributes_to_current_span(
{
Expand Down
1 change: 0 additions & 1 deletion chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,6 @@ def process_create_database(
tenant: str, headers: Headers, raw_body: bytes
) -> None:
db = CreateDatabase.model_validate(orjson.loads(raw_body))

(
maybe_tenant,
maybe_database,
Expand Down
2 changes: 2 additions & 0 deletions clients/js/src/ChromaClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ export class ChromaClient {
this.database,
{
name,
configuration: null, //TODO: Configuration type in JavaScript
metadata,
},
this.api.options,
Expand Down Expand Up @@ -226,6 +227,7 @@ export class ChromaClient {
{
name,
metadata,
configuration: null,
get_or_create: true,
},
this.api.options,
Expand Down
3 changes: 3 additions & 0 deletions clients/js/src/generated/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export namespace Api {

export interface CreateCollection {
name: string;
configuration: Api.CreateCollection.Configuration | null;
metadata?: Api.CreateCollection.Metadata | null;
get_or_create?: boolean;
}
Expand All @@ -52,6 +53,8 @@ export namespace Api {
* @namespace CreateCollection
*/
export namespace CreateCollection {
export interface Configuration {}

export interface Metadata {}
}

Expand Down
1 change: 1 addition & 0 deletions clients/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export type CollectionType = {
name: string;
id: string;
metadata: Metadata | null;
configuration_json: any;
};

export type GetResponse = {
Expand Down
14 changes: 12 additions & 2 deletions clients/js/test/collection.client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ test("it should create a collection", async () => {
expect(collection).toHaveProperty("id");
expect(collection.name).toBe("test");
let collections = await chroma.listCollections();

expect([
{
name: "test",
Expand All @@ -32,7 +33,12 @@ test("it should create a collection", async () => {
version: 0,
dimension: null,
},
]).toEqual(expect.arrayContaining(collections));
]).toEqual(
expect.arrayContaining(
collections.map(({ configuration_json, ...rest }) => rest),
),
);

expect([{ name: "test2", metadata: null }]).not.toEqual(
expect.arrayContaining(collections),
);
Expand Down Expand Up @@ -60,7 +66,11 @@ test("it should create a collection", async () => {
dimension: null,
version: 0,
},
]).toEqual(expect.arrayContaining(collections2));
]).toEqual(
expect.arrayContaining(
collections2.map(({ configuration_json, ...rest }) => rest),
),
);
});

test("it should get a collection", async () => {
Expand Down

0 comments on commit 55951e7

Please sign in to comment.