Skip to content

Commit

Permalink
Consolidate validators
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma authored and atroyn committed Oct 2, 2024
1 parent a8fd123 commit 4e27cd5
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 75 deletions.
2 changes: 1 addition & 1 deletion chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
GetResult,
QueryResult,
CollectionMetadata,
validate_batch,
convert_np_embeddings_to_list,
validate_batch,
)
from chromadb.auth import (
ClientAuthProvider,
Expand Down
39 changes: 4 additions & 35 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
validate_ids,
validate_include,
validate_metadata,
validate_metadatas,
validate_embeddings,
validate_embedding_function,
validate_n_results,
validate_where,
validate_where_document,
record_set_contains_one_of,
validate_record_set,
)

# TODO: We should rename the types in chromadb.types to be Models where
Expand Down Expand Up @@ -169,37 +169,6 @@ def _unpack_record_set(
"uris": maybe_cast_one_to_many(uris),
}

@staticmethod
def _validate_record_set(
record_set: RecordSet,
require_data: bool,
) -> None:
validate_ids(record_set["ids"])
validate_embeddings(record_set["embeddings"]) if record_set[
"embeddings"
] is not None else None
validate_metadatas(record_set["metadatas"]) if record_set[
"metadatas"
] is not None else None

# Only one of documents or images can be provided
if record_set["documents"] is not None and record_set["images"] is not None:
raise ValueError("You can only provide documents or images, not both.")

required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item]
if not require_data:
required_fields += ["metadatas"] # type: ignore[list-item]

if not record_set_contains_one_of(record_set, include=required_fields):
raise ValueError(f"You must provide one of {', '.join(required_fields)}")

valid_ids = record_set["ids"]
for key in ["embeddings", "metadatas", "documents", "images", "uris"]:
if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required]
raise ValueError(
f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required]
)

def _compute_embeddings(
self,
documents: Optional[Documents],
Expand Down Expand Up @@ -406,7 +375,7 @@ def _process_add_request(
uris=uris,
)

self._validate_record_set(
validate_record_set(
record_set,
require_data=True,
)
Expand Down Expand Up @@ -443,7 +412,7 @@ def _process_upsert_request(
uris=uris,
)

self._validate_record_set(
validate_record_set(
record_set,
require_data=True,
)
Expand Down Expand Up @@ -481,7 +450,7 @@ def _process_update_request(
uris=uris,
)

self._validate_record_set(
validate_record_set(
record_set,
require_data=False,
)
Expand Down
99 changes: 75 additions & 24 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from chromadb.errors import (
InvalidDimensionException,
InvalidCollectionException,
InvalidInputException,
VersionMismatchError,
)
from chromadb.api.types import (
Expand All @@ -34,13 +35,15 @@
Where,
WhereDocument,
Include,
RecordSet,
GetResult,
QueryResult,
validate_metadata,
validate_update_metadata,
validate_where,
validate_where_document,
validate_batch,
validate_record_set,
)
from chromadb.telemetry.product.events import (
CollectionAddEvent,
Expand Down Expand Up @@ -341,10 +344,20 @@ def _add(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_record_set(
collection=coll,
record_set={
"ids": ids,
"embeddings": embeddings,
"documents": documents,
"uris": uris,
"metadatas": metadatas,
"images": None,
},
require_data=True,
)

records_to_submit = list(
_records(
t.Operation.ADD,
Expand All @@ -355,7 +368,6 @@ def _add(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -383,10 +395,20 @@ def _update(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_record_set(
collection=coll,
record_set={
"ids": ids,
"embeddings": embeddings,
"documents": documents,
"uris": uris,
"metadatas": metadatas,
"images": None,
},
require_data=False,
)

records_to_submit = list(
_records(
t.Operation.UPDATE,
Expand All @@ -397,7 +419,6 @@ def _update(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -427,10 +448,20 @@ def _upsert(
self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},

self._validate_record_set(
collection=coll,
record_set={
"ids": ids,
"embeddings": embeddings,
"documents": documents,
"uris": uris,
"metadatas": metadatas,
"images": None,
},
require_data=True,
)

records_to_submit = list(
_records(
t.Operation.UPSERT,
Expand All @@ -441,7 +472,6 @@ def _upsert(
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

return True
Expand Down Expand Up @@ -630,7 +660,6 @@ def _delete(
records_to_submit = list(
_records(operation=t.Operation.DELETE, ids=ids_to_delete)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
Expand Down Expand Up @@ -851,19 +880,41 @@ def get_max_batch_size(self) -> int:
# system, since the cache is only local.
# TODO: promote collection -> topic to a base class method so that it can be
# used for channel assignment in the distributed version of the system.
@trace_method(
"SegmentAPI._validate_embedding_record_set", OpenTelemetryGranularity.ALL
)
def _validate_embedding_record_set(
self, collection: t.Collection, records: List[t.OperationRecord]
@trace_method("SegmentAPI._validate_record_set", OpenTelemetryGranularity.ALL)
def _validate_record_set(
self,
collection: t.Collection,
record_set: RecordSet,
require_data: bool,
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
add_attributes_to_current_span({"collection_id": str(collection["id"])})
for record in records:
if record["embedding"] is not None:
self._validate_dimension(
collection, len(record["embedding"]), update=True
)

try:
validate_record_set(record_set, require_data=require_data)
validate_batch(
(
record_set["ids"],
record_set["embeddings"],
record_set["metadatas"],
record_set["documents"],
record_set["uris"],
),
{"max_batch_size": self.get_max_batch_size()},
)

if require_data and record_set["embeddings"] is None:
raise ValueError("You must provide embeddings")

if record_set["embeddings"] is not None:
"""Validate the dimension of an embedding record before submitting it to the system."""
for embedding in record_set["embeddings"]:
if embedding:
self._validate_dimension(
collection, len(embedding), update=True
)

except ValueError as e:
raise InvalidInputException(f"{e}")

# This method is intentionally left untraced because otherwise it can emit thousands of spans for requests containing many embeddings.
def _validate_dimension(
Expand Down
33 changes: 32 additions & 1 deletion chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def maybe_cast_one_to_many(target: Optional[(OneOrMany[T])]) -> Optional[List[T]


def maybe_cast_one_to_many_embedding(
target: Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]
target: Union[Optional[OneOrMany[Embedding]], Optional[OneOrMany[PyEmbedding]]]
) -> Optional[Embeddings]:
if target is None:
return None
Expand Down Expand Up @@ -616,6 +616,37 @@ def validate_batch(
)


def validate_record_set(
record_set: RecordSet,
require_data: bool,
) -> None:
validate_ids(record_set["ids"])
validate_embeddings(record_set["embeddings"]) if record_set[
"embeddings"
] is not None else None
validate_metadatas(record_set["metadatas"]) if record_set[
"metadatas"
] is not None else None

# Only one of documents or images can be provided
if record_set["documents"] is not None and record_set["images"] is not None:
raise ValueError("You can only provide documents or images, not both.")

required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item]
if not require_data:
required_fields += ["metadatas"] # type: ignore[list-item]

if not record_set_contains_one_of(record_set, include=required_fields):
raise ValueError(f"You must provide one of {', '.join(required_fields)}")

valid_ids = record_set["ids"]
for key in ["embeddings", "metadatas", "documents", "images", "uris"]:
if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required]
raise ValueError(
f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required]
)


def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings:
return [embedding.tolist() for embedding in embeddings]

Expand Down
8 changes: 8 additions & 0 deletions chromadb/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ def name(cls) -> str:
pass


class InvalidInputException(ChromaError):
@classmethod
@overrides
def name(cls) -> str:
return "InvalidInput"


class InvalidDimensionException(ChromaError):
@classmethod
@overrides
Expand Down Expand Up @@ -137,6 +144,7 @@ def name(cls) -> str:


error_types: Dict[str, Type[ChromaError]] = {
"InvalidInput": InvalidInputException,
"InvalidDimension": InvalidDimensionException,
"InvalidCollection": InvalidCollectionException,
"IDAlreadyExists": IDAlreadyExistsError,
Expand Down
36 changes: 33 additions & 3 deletions chromadb/test/api/test_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
RecordSet,
record_set_contains_one_of,
maybe_cast_one_to_many_embedding,
validate_embeddings,
Embeddings,
)


Expand Down Expand Up @@ -53,13 +55,13 @@ def test_maybe_cast_one_to_many_embedding() -> None:
assert maybe_cast_one_to_many_embedding(None) is None

# Test with a single embedding as a list
single_embedding = [1.0, 2.0, 3.0]
single_embedding = np.array([1.0, 2.0, 3.0])
result = maybe_cast_one_to_many_embedding(single_embedding)
assert result == [single_embedding]

# Test with multiple embeddings as a list of lists
multiple_embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
result = maybe_cast_one_to_many_embedding(multiple_embeddings) # type: ignore[arg-type]
multiple_embeddings = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])]
result = maybe_cast_one_to_many_embedding(multiple_embeddings)
assert result == multiple_embeddings

# Test with a numpy array (single embedding)
Expand Down Expand Up @@ -96,3 +98,31 @@ def test_maybe_cast_one_to_many_embedding() -> None:
match="Expected embeddings to be a list or a numpy array, got str",
):
maybe_cast_one_to_many_embedding("") # type: ignore[arg-type]


def test_embeddings_validation() -> None:
invalid_embeddings = [[0, 0, True], [1.2, 2.24, 3.2]]

with pytest.raises(ValueError) as e:
validate_embeddings(invalid_embeddings) # type: ignore[arg-type]

assert "Expected each value in the embedding to be a int or float" in str(e)

invalid_embeddings = [[0, 0, "invalid"], [1.2, 2.24, 3.2]]

with pytest.raises(ValueError) as e:
validate_embeddings(invalid_embeddings) # type: ignore[arg-type]

assert "Expected each value in the embedding to be a int or float" in str(e)

with pytest.raises(ValueError) as e:
validate_embeddings("invalid") # type: ignore[arg-type]

assert "Expected embeddings to be a list, got str" in str(e)


def test_0dim_embedding_validation() -> None:
embds: Embeddings = [[]] # type: ignore[list-item]
with pytest.raises(ValueError) as e:
validate_embeddings(embds)
assert "Expected each embedding in the embeddings to be a non-empty list" in str(e)
Loading

0 comments on commit 4e27cd5

Please sign in to comment.