diff --git a/chromadb/api/configuration.py b/chromadb/api/configuration.py index 0e2c487f915..ebc1c9ddcef 100644 --- a/chromadb/api/configuration.py +++ b/chromadb/api/configuration.py @@ -26,6 +26,12 @@ class StaticParameterError(Exception): pass +class InvalidConfigurationError(ValueError): + """Represents an error that occurs when a configuration is invalid.""" + + pass + + ParameterValue = Union[str, int, float, bool, "ConfigurationInternal"] @@ -110,8 +116,8 @@ def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None): if not isinstance(parameter.value, type(definition.default_value)): raise ValueError(f"Invalid parameter value: {parameter.value}") - validator = definition.validator - if not validator(parameter.value): + parameter_validator = definition.validator + if not parameter_validator(parameter.value): raise ValueError(f"Invalid parameter value: {parameter.value}") self.parameter_map[parameter.name] = parameter # Apply the defaults for any missing parameters @@ -121,6 +127,8 @@ def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None): name=name, value=definition.default_value ) + self.configuration_validator() + def __repr__(self) -> str: return f"Configuration({self.parameter_map.values()})" @@ -129,6 +137,14 @@ def __eq__(self, __value: object) -> bool: return NotImplemented return self.parameter_map == __value.parameter_map + @abstractmethod + def configuration_validator(self) -> None: + """Perform custom validation when parameters are dependent on each other. + + Raises an InvalidConfigurationError if the configuration is invalid. + """ + pass + def get_parameters(self) -> List[ConfigurationParameter]: """Returns the parameters of the configuration.""" return list(self.parameter_map.values()) @@ -247,16 +263,30 @@ class HNSWConfigurationInternal(ConfigurationInternal): name="batch_size", validator=lambda value: isinstance(value, int) and value >= 1, is_static=True, - default_value=1000, + default_value=100, ), "sync_threshold": ConfigurationDefinition( name="sync_threshold", validator=lambda value: isinstance(value, int) and value >= 1, is_static=True, - default_value=100, + default_value=1000, ), } + @override + def configuration_validator(self) -> None: + batch_size = self.parameter_map.get("batch_size") + sync_threshold = self.parameter_map.get("sync_threshold") + + if ( + batch_size + and sync_threshold + and cast(int, batch_size.value) > cast(int, sync_threshold.value) + ): + raise InvalidConfigurationError( + "batch_size must be less than or equal to sync_threshold" + ) + @classmethod def from_legacy_params(cls, params: Dict[str, Any]) -> Self: """Returns an HNSWConfiguration from a metadata dict containing legacy HNSW parameters. Used for migration.""" @@ -302,8 +332,8 @@ def __init__( num_threads: int = cpu_count(), M: int = 16, resize_factor: float = 1.2, - batch_size: int = 1000, - sync_threshold: int = 100, + batch_size: int = 100, + sync_threshold: int = 1000, ): parameters = [ ConfigurationParameter(name="space", value=space), @@ -336,6 +366,10 @@ class CollectionConfigurationInternal(ConfigurationInternal): ), } + @override + def configuration_validator(self) -> None: + pass + # This is the user-facing interface for HNSW index configuration parameters. # Internally, we pass around HNSWConfigurationInternal objects, which perform diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index 2684441cbb0..0c91903afa2 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -1,3 +1,4 @@ +import json from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set from uuid import UUID from overrides import override @@ -8,6 +9,7 @@ CollectionConfigurationInternal, ConfigurationParameter, HNSWConfigurationInternal, + InvalidConfigurationError, ) from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import ( @@ -435,8 +437,8 @@ def get_collections( metadata = self._metadata_from_rows(rows) dimension = int(rows[0][3]) if rows[0][3] else None if rows[0][2] is not None: - configuration = CollectionConfigurationInternal.from_json_str( - rows[0][2] + configuration = self._load_config_from_json_str_and_migrate( + str(collection_id), rows[0][2] ) else: # 07/2024: This is a legacy case where we don't have a collection @@ -764,6 +766,56 @@ def _insert_metadata( if sql: cur.execute(sql, params) + def _load_config_from_json_str_and_migrate( + self, collection_id: str, json_str: str + ) -> CollectionConfigurationInternal: + try: + config_json = json.loads(json_str) + except json.JSONDecodeError: + raise ValueError( + f"Unable to decode configuration from JSON string: {json_str}" + ) + + try: + return CollectionConfigurationInternal.from_json_str(json_str) + except InvalidConfigurationError as error: + # 07/17/2024: the initial migration from the legacy metadata-based config to the new sysdb-based config had a bug where the batch_size and sync_threshold were swapped. Along with this migration, a validator was added to HNSWConfigurationInternal to ensure that batch_size <= sync_threshold. + hnsw_configuration = config_json.get("hnsw_configuration") + if hnsw_configuration: + batch_size = hnsw_configuration.get("batch_size") + sync_threshold = hnsw_configuration.get("sync_threshold") + + if batch_size and sync_threshold and batch_size > sync_threshold: + # Allow new defaults to be set + hnsw_configuration = { + k: v + for k, v in hnsw_configuration.items() + if k not in ["batch_size", "sync_threshold"] + } + config_json.update({"hnsw_configuration": hnsw_configuration}) + + configuration = CollectionConfigurationInternal.from_json( + config_json + ) + + collections_t = Table("collections") + q = ( + self.querybuilder() + .update(collections_t) + .set( + collections_t.config_json_str, + ParameterValue(configuration.to_json_str()), + ) + .where(collections_t.id == ParameterValue(collection_id)) + ) + sql, params = get_sql(q, self.parameter_format()) + with self.tx() as cur: + cur.execute(sql, params) + + return configuration + + raise error + def _insert_config_from_legacy_params( self, collection_id: Any, metadata: Optional[Metadata] ) -> CollectionConfigurationInternal: diff --git a/chromadb/test/configurations/test_configurations.py b/chromadb/test/configurations/test_configurations.py index 012f1cc939c..0d952957a73 100644 --- a/chromadb/test/configurations/test_configurations.py +++ b/chromadb/test/configurations/test_configurations.py @@ -1,9 +1,12 @@ +from overrides import overrides import pytest from chromadb.api.configuration import ( ConfigurationInternal, ConfigurationDefinition, + InvalidConfigurationError, StaticParameterError, ConfigurationParameter, + HNSWConfiguration, ) @@ -23,6 +26,10 @@ class TestConfiguration(ConfigurationInternal): ), } + @overrides + def configuration_validator(self) -> None: + pass + def test_default_values() -> None: default_test_configuration = TestConfiguration() @@ -76,3 +83,28 @@ def test_validation() -> None: ] with pytest.raises(ValueError): TestConfiguration(parameters=invalid_parameter_names) + + +def test_configuration_validation() -> None: + class FooConfiguration(ConfigurationInternal): + definitions = { + "foo": ConfigurationDefinition( + name="foo", + validator=lambda value: isinstance(value, str), + is_static=False, + default_value="default", + ), + } + + @overrides + def configuration_validator(self) -> None: + if self.parameter_map.get("foo") != "bar": + raise InvalidConfigurationError("foo must be 'bar'") + + with pytest.raises(ValueError, match="foo must be 'bar'"): + FooConfiguration(parameters=[ConfigurationParameter(name="foo", value="baz")]) + + +def test_hnsw_validation() -> None: + with pytest.raises(ValueError, match="must be less than or equal"): + HNSWConfiguration(batch_size=500, sync_threshold=100) diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 77f8525310d..a0db1ca6ced 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -306,12 +306,17 @@ def collections( metadata = {} metadata.update(test_hnsw_config) if use_persistent_hnsw_params: - metadata["hnsw:batch_size"] = draw( - st.integers(min_value=3, max_value=max_hnsw_batch_size) - ) metadata["hnsw:sync_threshold"] = draw( st.integers(min_value=3, max_value=max_hnsw_sync_threshold) ) + metadata["hnsw:batch_size"] = draw( + st.integers( + min_value=3, + max_value=min( + [metadata["hnsw:sync_threshold"], max_hnsw_batch_size] + ), + ) + ) # Sometimes, select a space at random if draw(st.booleans()): # TODO: pull the distance functions from a source of truth that lives not