Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] fix HNSW param defaults in new configuration logic & require batch_size < sync_threshold #2526

Merged
merged 8 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions chromadb/api/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
List,
Optional,
Protocol,
Tuple,
Union,
TypeVar,
cast,
Expand Down Expand Up @@ -121,6 +122,12 @@ def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
name=name, value=definition.default_value
)

(is_valid, error_msg) = self.validator()
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
if not is_valid:
if error_msg:
raise ValueError(f"Invalid configuration: {error_msg}")
raise ValueError("Invalid configuration")

def __repr__(self) -> str:
return f"Configuration({self.parameter_map.values()})"

Expand All @@ -129,6 +136,13 @@ def __eq__(self, __value: object) -> bool:
return NotImplemented
return self.parameter_map == __value.parameter_map

def validator(self) -> Tuple[bool, Optional[str]]:
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
"""Perform custom validation when parameters are dependent on each other.

Returns a tuple with a boolean indicating whether the configuration is valid and an optional error message.
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
"""
return (True, None)

def get_parameters(self) -> List[ConfigurationParameter]:
"""Returns the parameters of the configuration."""
return list(self.parameter_map.values())
Expand Down Expand Up @@ -247,16 +261,30 @@ class HNSWConfigurationInternal(ConfigurationInternal):
name="batch_size",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=1000,
default_value=100,
),
"sync_threshold": ConfigurationDefinition(
name="sync_threshold",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=100,
default_value=1000,
),
}

@override
def validator(self) -> Tuple[bool, Optional[str]]:
batch_size = self.parameter_map.get("batch_size")
sync_threshold = self.parameter_map.get("sync_threshold")

if (
batch_size
and sync_threshold
and cast(int, batch_size.value) > cast(int, sync_threshold.value)
):
return (False, "batch_size must be less than or equal to sync_threshold")

return super().validator()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can just return nothing if we don't except, if we make this something which raises exceptions but otherwise does nothing

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can probably take a parameter list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can probably take a parameter list?

sorry, what do you mean by this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right now this is an instance method that reads self.parameter_map but what we actually want is to see if a given parameter_list on creation corresponds to a valid configuration, similar to how ParameterValidator works for individual values.

This is mostly a stylistic choice though and I think it's fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it does kinda make sense for it to be a static method but finding the params you want from the list would be a little annoying


@classmethod
def from_legacy_params(cls, params: Dict[str, Any]) -> Self:
"""Returns an HNSWConfiguration from a metadata dict containing legacy HNSW parameters. Used for migration."""
Expand Down Expand Up @@ -302,8 +330,8 @@ def __init__(
num_threads: int = cpu_count(),
M: int = 16,
resize_factor: float = 1.2,
batch_size: int = 1000,
sync_threshold: int = 100,
batch_size: int = 100,
sync_threshold: int = 1000,
atroyn marked this conversation as resolved.
Show resolved Hide resolved
):
parameters = [
ConfigurationParameter(name="space", value=space),
Expand Down
50 changes: 48 additions & 2 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set
from uuid import UUID
from overrides import override
Expand Down Expand Up @@ -435,8 +436,8 @@ def get_collections(
metadata = self._metadata_from_rows(rows)
dimension = int(rows[0][3]) if rows[0][3] else None
if rows[0][2] is not None:
configuration = CollectionConfigurationInternal.from_json_str(
rows[0][2]
configuration = self._load_config_from_json_str_and_migrate(
str(collection_id), rows[0][2]
)
else:
# 07/2024: This is a legacy case where we don't have a collection
Expand Down Expand Up @@ -764,6 +765,51 @@ def _insert_metadata(
if sql:
cur.execute(sql, params)

def _load_config_from_json_str_and_migrate(
self, collection_id: str, json_str: str
) -> CollectionConfigurationInternal:
try:
config_json = json.loads(json_str)
except json.JSONDecodeError:
raise ValueError(
f"Unable to decode configuration from JSON string: {json_str}"
)

# 07/17/2024: the initial migration from the legacy metadata-based config to the new sysdb-based config had a bug where the batch_size and sync_threshold were swapped. Along with this migration, a validator was added to HNSWConfigurationInternal to ensure that batch_size <= sync_threshold.
hnsw_configuration = config_json.get("hnsw_configuration")
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
if hnsw_configuration:
batch_size = hnsw_configuration.get("batch_size")
sync_threshold = hnsw_configuration.get("sync_threshold")

if batch_size and sync_threshold and batch_size > sync_threshold:
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
config_json["hnsw_configuration"][
"sync_threshold"
] = HNSWConfigurationInternal.definitions[
"sync_threshold"
].default_value
config_json["hnsw_configuration"][
"batch_size"
] = HNSWConfigurationInternal.definitions["batch_size"].default_value
configuration = CollectionConfigurationInternal.from_json(config_json)

collections_t = Table("collections")
q = (
self.querybuilder()
.update(collections_t)
.set(
collections_t.config_json_str,
ParameterValue(configuration.to_json_str()),
)
.where(collections_t.id == ParameterValue(collection_id))
)
sql, params = get_sql(q, self.parameter_format())
with self.tx() as cur:
cur.execute(sql, params)

return configuration

return CollectionConfigurationInternal.from_json_str(json_str)
Copy link
Contributor Author

@codetheweb codetheweb Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered making migrations a generic concept for CollectionConfigurations, but after starting to implement that it seemed like a lot of abstraction overhead for now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah let's punt on that


def _insert_config_from_legacy_params(
self, collection_id: Any, metadata: Optional[Metadata]
) -> CollectionConfigurationInternal:
Expand Down
6 changes: 6 additions & 0 deletions chromadb/test/configurations/test_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
ConfigurationDefinition,
StaticParameterError,
ConfigurationParameter,
HNSWConfiguration,
)


Expand Down Expand Up @@ -76,3 +77,8 @@ def test_validation() -> None:
]
with pytest.raises(ValueError):
TestConfiguration(parameters=invalid_parameter_names)


def test_hnsw_validation() -> None:
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(ValueError, match="must be less than or equal"):
HNSWConfiguration(batch_size=500, sync_threshold=100)
11 changes: 8 additions & 3 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,17 @@ def collections(
metadata = {}
metadata.update(test_hnsw_config)
if use_persistent_hnsw_params:
metadata["hnsw:batch_size"] = draw(
st.integers(min_value=3, max_value=max_hnsw_batch_size)
)
metadata["hnsw:sync_threshold"] = draw(
st.integers(min_value=3, max_value=max_hnsw_sync_threshold)
)
metadata["hnsw:batch_size"] = draw(
st.integers(
min_value=3,
max_value=min(
[metadata["hnsw:sync_threshold"], max_hnsw_batch_size]
),
)
)
# Sometimes, select a space at random
if draw(st.booleans()):
# TODO: pull the distance functions from a source of truth that lives not
Expand Down
Loading