Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[STACKED][ENH] Metadata Indexing - Implementation #1443

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Where,
QueryResult,
GetResult,
WhereDocument,
WhereDocument, SqlBackedIndex,
)
from chromadb.config import Component, Settings
from chromadb.types import Database, Tenant
Expand Down Expand Up @@ -412,6 +412,24 @@ def max_batch_size(self) -> int:
to submit_embeddings."""
pass

@abstractmethod
def _create_collection_indices(self, collection_id: UUID, indices: Sequence[SqlBackedIndex]) -> None:
"""Create a new index """
pass

@abstractmethod
def _drop_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
pass

@abstractmethod
def _list_collection_indices(self, collection_id: UUID) -> Sequence[SqlBackedIndex]:
pass

@abstractmethod
def _rebuild_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
pass



class ClientAPI(BaseAPI, ABC):
tenant: str
Expand Down
18 changes: 17 additions & 1 deletion chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Loadable,
Metadatas,
QueryResult,
URIs,
URIs, SqlBackedIndex,
)
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
Expand Down Expand Up @@ -444,6 +444,22 @@ def _validate_tenant_database(self, tenant: str, database: str) -> None:
raise ValueError(
f"Could not connect to database {database} for tenant {tenant}. Are you sure it exists?"
)
@override
def _create_collection_indices(self, collection_id: UUID, indices: Sequence[SqlBackedIndex]) -> None:
"""Create a new index """
pass

@override
def _drop_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
pass

@override
def _list_collection_indices(self, collection_id: UUID) -> Sequence[SqlBackedIndex]:
pass

@override
def _rebuild_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
pass

# endregion

Expand Down
19 changes: 18 additions & 1 deletion chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
GetResult,
QueryResult,
CollectionMetadata,
validate_batch,
validate_batch, SqlBackedIndex,
)
from chromadb.auth import (
ClientAuthProvider,
Expand Down Expand Up @@ -605,6 +605,23 @@ def max_batch_size(self) -> int:
self._max_batch_size = cast(int, resp.json()["max_batch_size"])
return self._max_batch_size

@override
def _create_collection_indices(self, collection_id: UUID, indices: Sequence[SqlBackedIndex]) -> None:
"""Create a new index """
pass

@override
def _drop_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
pass

@override
def _list_collection_indices(self, collection_id: UUID) -> Sequence[SqlBackedIndex]:
pass

@override
def _rebuild_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
pass


def raise_chroma_error(resp: requests.Response) -> None:
"""Raises an error if the response is not ok, using a ChromaError if possible"""
Expand Down
16 changes: 14 additions & 2 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Tuple, Any
from typing import TYPE_CHECKING, Optional, Tuple, Any, Sequence
from pydantic import BaseModel, PrivateAttr

from uuid import UUID
Expand Down Expand Up @@ -42,7 +42,7 @@
validate_where_document,
validate_n_results,
validate_embeddings,
validate_embedding_function,
validate_embedding_function, SqlBackedIndex,
)
import logging

Expand Down Expand Up @@ -585,3 +585,15 @@ def _embed(self, input: Any) -> Embeddings:
"https://docs.trychroma.com/embeddings"
)
return self._embedding_function(input=input)

def add_indices(self, indices: Sequence[SqlBackedIndex]) -> None:
return self._client._create_collection_indices(self.id, indices)

def rebuild_indices(self, indices: Sequence[str]=None) -> None:
return self._client._rebuild_collection_indices(self.id, indices)

def drop_indices(self, indices: Optional[Sequence[str]]= None) -> None:
return self._client._drop_collection_indices(self.id, indices)

def list_indices(self) -> Sequence[SqlBackedIndex]:
return self._client._list_collection_indices(self.id)
20 changes: 19 additions & 1 deletion chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
validate_update_metadata,
validate_where,
validate_where_document,
validate_batch,
validate_batch, SqlBackedIndex,
)
from chromadb.telemetry.product.events import (
CollectionAddEvent,
Expand Down Expand Up @@ -768,6 +768,24 @@ def get_settings(self) -> Settings:
def max_batch_size(self) -> int:
return self._producer.max_batch_size

@override
def _create_collection_indices(self, collection_id: UUID, indices: Sequence[SqlBackedIndex]) -> None:
"""Create a new index """
return self._sysdb.create_indices(collection_id, indices)

@override
def _drop_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
return self._sysdb.drop_indices(collection_id, index_names)

@override
def _list_collection_indices(self, collection_id: UUID) -> Sequence[SqlBackedIndex]:
return self._sysdb.list_indices(collection_id)

@override
def _rebuild_collection_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
return self._sysdb.rebuild_indices(collection_id, index_names)


# TODO: This could potentially cause race conditions in a distributed version of the
# system, since the cache is only local.
# TODO: promote collection -> topic to a base class method so that it can be
Expand Down
16 changes: 15 additions & 1 deletion chromadb/api/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Sequence, Union, TypeVar, List, Dict, Any, Tuple, cast
from enum import Enum
from typing import Optional, Sequence, Union, TypeVar, List, Dict, Any, Tuple, cast, Set
from numpy.typing import NDArray
import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
Expand Down Expand Up @@ -188,6 +189,19 @@ class IndexMetadata(TypedDict):
total_elements_added: int
time_created: float

AllowedIndexColumns = Literal["string_value", "int_value", "float_value"]


class IndexType(str, Enum):
METADATA = "metadata"
DOCUMENT = "document"


class SqlBackedIndex(TypedDict):
name: str
columns: Set[AllowedIndexColumns]
keys: Optional[Set[str]]
index_type: IndexType

Embeddable = Union[Documents, Images]
D = TypeVar("D", bound=Embeddable, contravariant=True)
Expand Down
98 changes: 97 additions & 1 deletion chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set
from uuid import UUID
from overrides import override
from pypika import Table, Column
from pypika import Table, Column, Criterion
from itertools import groupby

from chromadb.api import SqlBackedIndex
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System
from chromadb.db.base import (
Cursor,
Expand All @@ -14,6 +15,7 @@
UniqueConstraintError,
)
from chromadb.db.system import SysDB
from chromadb.db.utils import IndexQuery
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
Expand Down Expand Up @@ -736,3 +738,97 @@ def _insert_metadata(
sql, params = get_sql(q, self.parameter_format())
if sql:
cur.execute(sql, params)

@override
def create_indices(self, collection_id: UUID, indices: Sequence[SqlBackedIndex]) -> None:
segments = Table('segments')
embedding_metadata = Table('embedding_metadata')

segment_id_q = (self.querybuilder()
.from_(segments)
.select(segments.id)
.where(segments.collection == ParameterValue(self.uuid_to_db(collection_id)))
.where(segments.scope == 'METADATA'))

with self.tx() as cur:
segment_id_sql, segment_id_params = get_sql(segment_id_q, self.parameter_format())
cur.execute(segment_id_sql, segment_id_params)
segment_id = cur.fetchone()[0]

for index in indices:
q = (IndexQuery()
.create_index(f'cust_idx_{index["name"]}')
.on(embedding_metadata)
.columns(embedding_metadata.key, embedding_metadata.segment_id,
*(embedding_metadata[col] for col in index["columns"]))
.where(embedding_metadata.key != 'chroma:document')
.where(embedding_metadata.segment_id == segment_id))
if 'keys' in index:
where_clause = None
for key in index["keys"]:
condition = (embedding_metadata.key == key)
where_clause = condition if where_clause is None else (where_clause | condition)
q = q.where(Criterion.any([(embedding_metadata.key == key) for key in index["keys"]]))
print(q.get_sql())
cur.execute(q.get_sql())

@override
def drop_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
sql_master = Table('sqlite_master')

with self.tx() as cur:
for index_name in index_names:
query_indices = (self.querybuilder()
.from_(sql_master)
.select(sql_master.name)
.where(sql_master.type == 'index')
.where(sql_master.tbl_name == 'embedding_metadata')
.where(sql_master.sql.like(f'%cust_idx_{index_name}%')))

# Execute query to find index
cur.execute(query_indices.get_sql())
index_to_drop = cur.fetchone()

# Check if index exists
if not index_to_drop:
raise NotFoundError(f"Index {index_name} not found")

# Build query to drop index
q = IndexQuery().drop_index(index_to_drop[0])

# Execute query to drop index
cur.execute(q.get_sql())


@override
def rebuild_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
# Define the table for querying index information
sql_master = Table('sqlite_master') # or appropriate table for your database
# Start a transaction
with self.tx() as cur:
for index_name in index_names:
# Query to find the index
query_indices = (self.querybuilder()
.from_(sql_master)
.select(sql_master.name, sql_master.sql)
.where(sql_master.type == 'index')
.where(sql_master.tbl_name == 'your_table_name') # adjust as needed
.where(sql_master.name == f'cust_idx_{index_name}'))
cur.execute(query_indices.get_sql())
index_info = cur.fetchone()
if index_info:
index_name, index_definition = index_info
drop_query = IndexQuery().drop_index(index_name)
cur.execute(drop_query.get_sql())
cur.execute(index_definition)
print(f"Rebuilt index: {index_name}")
else:
# Optionally, log that the index was not found
raise NotFoundError(f"Index {index_name} not found")



@override
def list_indices(self, collection_id: UUID) -> Sequence[SqlBackedIndex]:
print("Getting indices")
return []
22 changes: 22 additions & 0 deletions chromadb/db/system.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import abstractmethod
from typing import Optional, Sequence, Tuple
from uuid import UUID

from chromadb.api import SqlBackedIndex
from chromadb.types import (
Collection,
Database,
Expand Down Expand Up @@ -134,3 +136,23 @@ def update_collection(
keys with None values will be removed and keys not present in the UpdateMetadata
dict will be left unchanged."""
pass

@abstractmethod
def create_indices(self, collection_id: UUID, indices: Sequence[SqlBackedIndex]) -> None:
"""Create indices for a collection"""
pass

@abstractmethod
def drop_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
"""Delete all or select indices for a collection"""
pass

@abstractmethod
def rebuild_indices(self, collection_id: UUID, index_names: Optional[Sequence[str]]) -> None:
"""Rebuild all or select indices for a collection"""
pass

@abstractmethod
def list_indices(self, collection_id: UUID) -> Sequence[SqlBackedIndex]:
"""Get all indices for a collection"""
pass
Loading
Loading