Skip to content

Commit

Permalink
[ENH] Boolean type support + bug fixes + enable test_filtering (#2324)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
Adds support for boolean metadata type for both add() and query()

## Test plan
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
sanketkedia committed Jun 14, 2024
1 parent 6c6cb55 commit ef75293
Show file tree
Hide file tree
Showing 33 changed files with 1,509 additions and 649 deletions.
1 change: 1 addition & 0 deletions .github/workflows/_python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
test-globs: ["chromadb/test/db/test_system.py",
"chromadb/test/property/test_collections.py",
"chromadb/test/property/test_add.py",
"chromadb/test/property/test_filtering.py",
"chromadb/test/property/test_collections_with_database_tenant.py",
"chromadb/test/property/test_collections_with_database_tenant_overwrite.py",
"chromadb/test/ingest/test_producer_consumer.py",
Expand Down
33 changes: 26 additions & 7 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def get_segments(
metadata_t.str_value,
metadata_t.int_value,
metadata_t.float_value,
metadata_t.bool_value,
)
.left_join(metadata_t)
.on(segments_t.id == metadata_t.segment_id)
Expand Down Expand Up @@ -384,6 +385,7 @@ def get_collections(
metadata_t.str_value,
metadata_t.int_value,
metadata_t.float_value,
metadata_t.bool_value,
)
.left_join(metadata_t)
.on(collections_t.id == metadata_t.collection_id)
Expand Down Expand Up @@ -636,15 +638,17 @@ def _metadata_from_rows(
"num_rows": len(rows),
}
)
metadata: Dict[str, Union[str, int, float]] = {}
metadata: Dict[str, Union[str, int, float, bool]] = {}
for row in rows:
key = str(row[-4])
if row[-3] is not None:
metadata[key] = str(row[-3])
key = str(row[-5])
if row[-4] is not None:
metadata[key] = str(row[-4])
elif row[-3] is not None:
metadata[key] = int(row[-3])
elif row[-2] is not None:
metadata[key] = int(row[-2])
metadata[key] = float(row[-2])
elif row[-1] is not None:
metadata[key] = float(row[-1])
metadata[key] = bool(row[-1])
return metadata or None

@trace_method("SqlSysDB._insert_metadata", OpenTelemetryGranularity.ALL)
Expand Down Expand Up @@ -685,17 +689,30 @@ def _insert_metadata(
table.str_value,
table.int_value,
table.float_value,
table.bool_value,
)
)
sql_id = self.uuid_to_db(id)
for k, v in metadata.items():
if isinstance(v, str):
# Note: The order is important here because isinstance(v, bool)
# and isinstance(v, int) both are true for v of bool type.
if isinstance(v, bool):
q = q.insert(
ParameterValue(sql_id),
ParameterValue(k),
None,
None,
None,
ParameterValue(int(v)),
)
elif isinstance(v, str):
q = q.insert(
ParameterValue(sql_id),
ParameterValue(k),
ParameterValue(v),
None,
None,
None,
)
elif isinstance(v, int):
q = q.insert(
Expand All @@ -704,6 +721,7 @@ def _insert_metadata(
None,
ParameterValue(v),
None,
None,
)
elif isinstance(v, float):
q = q.insert(
Expand All @@ -712,6 +730,7 @@ def _insert_metadata(
None,
None,
ParameterValue(v),
None,
)
elif v is None:
continue
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- SQLite does not support adding check with alter table, as a result, adding a check
-- involve creating a new table and copying the data over. It is over kill with adding
-- a boolean type column. The application write to the table needs to ensure the data
-- integrity.
ALTER TABLE collection_metadata ADD COLUMN bool_value INTEGER;
ALTER TABLE segment_metadata ADD COLUMN bool_value INTEGER;
156 changes: 80 additions & 76 deletions chromadb/proto/chroma_pb2.py

Large diffs are not rendered by default.

30 changes: 26 additions & 4 deletions chromadb/proto/chroma_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions chromadb/proto/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ def _from_proto_metadata_handle_none(
) -> Optional[Union[UpdateMetadata, Metadata]]:
if not metadata.metadata:
return None
out_metadata: Dict[str, Union[str, int, float, None]] = {}
out_metadata: Dict[str, Union[str, int, float, bool, None]] = {}
for key, value in metadata.metadata.items():
if value.HasField("string_value"):
if value.HasField("bool_value"):
out_metadata[key] = value.bool_value
elif value.HasField("string_value"):
out_metadata[key] = value.string_value
elif value.HasField("int_value"):
out_metadata[key] = value.int_value
Expand Down Expand Up @@ -174,9 +176,14 @@ def to_proto_segment_scope(segment_scope: SegmentScope) -> proto.SegmentScope:


def to_proto_metadata_update_value(
value: Union[str, int, float, None]
value: Union[str, int, float, bool, None]
) -> proto.UpdateMetadataValue:
if isinstance(value, str):
# Be careful with the order here. Since bools are a subtype of int in python,
# isinstance(value, bool) and isinstance(value, int) both return true
# for a value of bool type.
if isinstance(value, bool):
return proto.UpdateMetadataValue(bool_value=value)
elif isinstance(value, str):
return proto.UpdateMetadataValue(string_value=value)
elif isinstance(value, int):
return proto.UpdateMetadataValue(int_value=value)
Expand Down
35 changes: 31 additions & 4 deletions chromadb/segment/impl/metadata/grpc_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ def get_metadata(

request: pb.QueryMetadataRequest = pb.QueryMetadataRequest(
segment_id=self._segment["id"].hex,
where=self._where_to_proto(where) if where is not None else None,
where=self._where_to_proto(where)
if where is not None and len(where) > 0
else None,
where_document=(
self._where_document_to_proto(where_document)
if where_document is not None
if where_document is not None and len(where_document) > 0
else None
),
ids=ids,
Expand Down Expand Up @@ -144,6 +146,11 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where:
ssc.value = value
ssc.comparator = pb.GenericComparator.EQ
dc.single_string_operand.CopyFrom(ssc)
elif type(value) is bool:
sbc = pb.SingleBoolComparison()
sbc.value = value
sbc.comparator = pb.GenericComparator.EQ
dc.single_bool_operand.CopyFrom(sbc)
elif type(value) is int:
sic = pb.SingleIntComparison()
sic.value = value
Expand Down Expand Up @@ -183,6 +190,12 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where:
slo.values.extend([x]) # type: ignore
slo.list_operator = list_operator
dc.string_list_operand.CopyFrom(slo)
elif type(operand[0]) is bool:
blo = pb.BoolListComparison()
for x in operand:
blo.values.extend([x]) # type: ignore
blo.list_operator = list_operator
dc.bool_list_operand.CopyFrom(blo)
elif type(operand[0]) is int:
ilo = pb.IntListComparison()
for x in operand:
Expand Down Expand Up @@ -213,6 +226,18 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where:
f"Expected where operator to be $eq or $ne, got {operator}"
)
dc.single_string_operand.CopyFrom(ssc)
elif type(operand) is bool:
sbc = pb.SingleBoolComparison()
sbc.value = operand
if operator == "$eq":
sbc.comparator = pb.GenericComparator.EQ
elif operator == "$ne":
sbc.comparator = pb.GenericComparator.NE
else:
raise ValueError(
f"Expected where operator to be $eq or $ne, got {operator}"
)
dc.single_bool_operand.CopyFrom(sbc)
elif type(operand) is int:
sic = pb.SingleIntComparison()
sic.value = operand
Expand Down Expand Up @@ -316,10 +341,12 @@ def _where_document_to_proto(
def _from_proto(
self, record: pb.MetadataEmbeddingRecord
) -> MetadataEmbeddingRecord:
translated_metadata: Dict[str, str | int | float] = {}
translated_metadata: Dict[str, str | int | float | bool] = {}
record_metadata_map = record.metadata.metadata
for key, value in record_metadata_map.items():
if value.HasField("string_value"):
if value.HasField("bool_value"):
translated_metadata[key] = value.bool_value
elif value.HasField("string_value"):
translated_metadata[key] = value.string_value
elif value.HasField("int_value"):
translated_metadata[key] = value.int_value
Expand Down
6 changes: 6 additions & 0 deletions chromadb/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
hypothesis.settings.load_profile(CURRENT_PRESET)


def reset(api: ServerAPI) -> None:
api.reset()
if not NOT_CLUSTER_ONLY:
time.sleep(MEMBERLIST_SLEEP)


def override_hypothesis_profile(
fast: Optional[hypothesis.settings] = None,
normal: Optional[hypothesis.settings] = None,
Expand Down
25 changes: 23 additions & 2 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from hypothesis.strategies._internal.strategies import SearchStrategy
from hypothesis.errors import InvalidDefinition
from hypothesis.stateful import RuleBasedStateMachine
from chromadb.test.conftest import NOT_CLUSTER_ONLY

from dataclasses import dataclass

Expand Down Expand Up @@ -112,10 +113,13 @@ class Record(TypedDict):
safe_integers = st.integers(
min_value=-(2**31), max_value=2**31 - 1
) # TODO: handle longs
# In distributed chroma, floats are 32 bit hence we need to
# restrict the generation to generate only 32 bit floats.
safe_floats = st.floats(
allow_infinity=False,
allow_nan=False,
allow_subnormal=False,
width=32,
min_value=-1e6,
max_value=1e6,
) # TODO: handle infinity and NAN
Expand Down Expand Up @@ -523,12 +527,21 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
key = draw(st.sampled_from(known_keys))
value = collection.known_metadata_keys[key]

legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"]
# This is hacky, but the distributed system does not support $in or $in so we
# need to avoid generating these operators for now in that case.
# TODO: Remove this once the distributed system supports $in and $nin
if not NOT_CLUSTER_ONLY:
legal_ops: List[Optional[str]] = [None, "$eq"]
else:
legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"]

if not isinstance(value, str) and not isinstance(value, bool):
legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])
if isinstance(value, float):
# Add or subtract a small number to avoid floating point rounding errors
value = value + draw(st.sampled_from([1e-6, -1e-6]))
# Truncate to 32 bit
value = float(np.float32(value))

op: WhereOperator = draw(st.sampled_from(legal_ops))

Expand All @@ -554,7 +567,15 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu
else:
word = draw(safe_text)

op: WhereOperator = draw(st.sampled_from(["$contains", "$not_contains"]))
# This is hacky, but the distributed system does not support $not_contains
# so we need to avoid generating these operators for now in that case.
# TODO: Remove this once the distributed system supports $not_contains
op: WhereOperator
if not NOT_CLUSTER_ONLY:
op = draw(st.sampled_from(["$contains"]))
else:
op = draw(st.sampled_from(["$contains", "$not_contains"]))

if op == "$contains":
return {"$contains": word}
else:
Expand Down
9 changes: 1 addition & 8 deletions chromadb/test/property/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from typing import cast, List, Any, Dict
import hypothesis
import pytest
import time
import hypothesis.strategies as st
from hypothesis import given, settings
from chromadb.api import ServerAPI
from chromadb.api.types import Embeddings, Metadatas
from chromadb.test.conftest import (
MEMBERLIST_SLEEP,
reset,
NOT_CLUSTER_ONLY,
override_hypothesis_profile,
)
Expand All @@ -23,12 +22,6 @@
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")


def reset(api: ServerAPI) -> None:
api.reset()
if not NOT_CLUSTER_ONLY:
time.sleep(MEMBERLIST_SLEEP)


@given(
collection=collection_st,
record_set=strategies.recordsets(collection_st),
Expand Down
Loading

0 comments on commit ef75293

Please sign in to comment.