diff --git a/.github/workflows/chroma-cluster-test.yml b/.github/workflows/chroma-cluster-test.yml
index 64209dbc21d..3935156da21 100644
--- a/.github/workflows/chroma-cluster-test.yml
+++ b/.github/workflows/chroma-cluster-test.yml
@@ -18,7 +18,8 @@ jobs:
         platform: ['16core-64gb-ubuntu-latest']
         testfile: ["chromadb/test/db/test_system.py",
                    "chromadb/test/ingest/test_producer_consumer.py",
-                   "chromadb/test/segment/distributed/test_memberlist_provider.py"]
+                   "chromadb/test/segment/distributed/test_memberlist_provider.py",
+                   "chromadb/test/test_logservice.py"]
     runs-on: ${{ matrix.platform }}
     steps:
       - name: Checkout
@@ -65,4 +66,4 @@ jobs:
       - name: Start Tilt
         run: tilt ci
       - name: Test
-        run: bin/cluster-test.sh bash -c 'cd go && go test -timeout 30s -run ^TestNodeWatcher$ github.com/chroma/chroma-coordinator/internal/memberlist_manager'
\ No newline at end of file
+        run: bin/cluster-test.sh bash -c 'cd go && go test -timeout 30s -run ^TestNodeWatcher$ github.com/chroma/chroma-coordinator/internal/memberlist_manager'
diff --git a/Tiltfile b/Tiltfile
index c54e95d5eed..1b1ba96b31f 100644
--- a/Tiltfile
+++ b/Tiltfile
@@ -47,6 +47,8 @@ k8s_resource(
     'coordinator-serviceaccount-rolebinding:RoleBinding',
     'coordinator-worker-memberlist-binding:clusterrolebinding',
+    'logservice-serviceaccount:serviceaccount',
+
     'worker-serviceaccount:serviceaccount',
     'worker-serviceaccount-rolebinding:RoleBinding',
     'worker-memberlist-readerwriter:ClusterRole',
@@ -65,14 +67,15 @@ k8s_resource(
 k8s_resource('postgres', resource_deps=['k8s_setup'], labels=["infrastructure"])
 k8s_resource('pulsar', resource_deps=['k8s_setup'], labels=["infrastructure"], port_forwards=['6650:6650', '8080:8080'])
 k8s_resource('migration', resource_deps=['postgres'], labels=["infrastructure"])
-k8s_resource('logservice', resource_deps=['migration'], labels=["chroma"])
-k8s_resource('frontend-server', resource_deps=['pulsar'],labels=["chroma"], port_forwards=8000 )
+k8s_resource('logservice', resource_deps=['migration'], labels=["chroma"], port_forwards='50052:50051')
+k8s_resource('frontend-server', resource_deps=['logservice'],labels=["chroma"], port_forwards=8000 )
 k8s_resource('coordinator', resource_deps=['pulsar', 'frontend-server', 'migration'], labels=["chroma"], port_forwards=50051)
 k8s_resource('worker', resource_deps=['coordinator'],labels=["chroma"])

 # Extra stuff to make debugging and testing easier
 k8s_yaml([
   'k8s/test/coordinator_service.yaml',
+  'k8s/test/logservice_service.yaml',
   'k8s/test/minio.yaml',
   'k8s/test/pulsar_service.yaml',
   'k8s/test/worker_service.yaml',
@@ -90,4 +93,4 @@ k8s_resource(
 )

 # Local S3
-k8s_resource('minio-deployment', resource_deps=['k8s_setup'], labels=["debug"], port_forwards=9000)
\ No newline at end of file
+k8s_resource('minio-deployment', resource_deps=['k8s_setup'], labels=["debug"], port_forwards=9000)
diff --git a/bin/cluster-test.sh b/bin/cluster-test.sh
index 75716d769a9..375ca464d00 100755
--- a/bin/cluster-test.sh
+++ b/bin/cluster-test.sh
@@ -12,6 +12,7 @@ echo "Pulsar Broker is running at port $PULSAR_BROKER_URL"
 echo "Chroma Coordinator is running at port $CHROMA_COORDINATOR_HOST"

 kubectl -n chroma port-forward svc/coordinator-lb 50051:50051 &
+kubectl -n chroma port-forward svc/logservice-lb 50052:50051 &
 kubectl -n chroma port-forward svc/pulsar-lb 6650:6650 &
 kubectl -n chroma port-forward svc/pulsar-lb 8080:8080 &
 kubectl -n chroma port-forward svc/frontend-server 8000:8000 &
diff --git a/chromadb/config.py b/chromadb/config.py
index bc8234bc34d..9f9c2f50e45 100644
--- a/chromadb/config.py
+++ b/chromadb/config.py
@@ -76,12 +76,13 @@
     "chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
     "chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
     "chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",
-    "chromadb.rate_limiting.RateLimitingProvider": "chroma_rate_limiting_provider_impl"
+    "chromadb.rate_limiting.RateLimitingProvider": "chroma_rate_limiting_provider_impl",
 }

 DEFAULT_TENANT = "default_tenant"
 DEFAULT_DATABASE = "default_database"

+
 class Settings(BaseSettings):  # type: ignore
     environment: str = ""

@@ -101,8 +102,10 @@ class Settings(BaseSettings):  # type: ignore
     chroma_segment_manager_impl: str = (
         "chromadb.segment.impl.manager.local.LocalSegmentManager"
     )
-    chroma_quota_provider_impl:Optional[str] = None
-    chroma_rate_limiting_provider_impl:Optional[str] = None
+
+    chroma_quota_provider_impl: Optional[str] = None
+    chroma_rate_limiting_provider_impl: Optional[str] = None
+
     # Distributed architecture specific components
     chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
     chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
@@ -112,6 +115,9 @@ class Settings(BaseSettings):  # type: ignore
     worker_memberlist_name: str = "worker-memberlist"

     chroma_coordinator_host = "localhost"
+    chroma_logservice_host = "localhost"
+    chroma_logservice_port = 50052
+
     tenant_id: str = "default"
     topic_namespace: str = "default"

@@ -320,7 +326,10 @@ def __init__(self, settings: Settings):
             if settings[key] is not None:
                 raise ValueError(LEGACY_ERROR)

-        if settings["chroma_segment_cache_policy"] is not None and settings["chroma_segment_cache_policy"] != "LRU":
+        if (
+            settings["chroma_segment_cache_policy"] is not None
+            and settings["chroma_segment_cache_policy"] != "LRU"
+        ):
             logger.error(
                 f"Failed to set chroma_segment_cache_policy: Only LRU is available."
             )
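
Note: the new chroma_logservice_host / chroma_logservice_port settings follow the same pydantic BaseSettings convention as the existing fields, so they can be set through environment variables (as the frontend-server template later in this diff does) or passed explicitly. A minimal sketch, assuming the log service is reachable through the localhost:50052 port-forward that the Tiltfile and bin/cluster-test.sh set up; the producer/consumer impl strings are the ones the frontend-server template uses:

    from chromadb.config import Settings, System

    # Sketch only: point the frontend's producer/consumer at the log service.
    settings = Settings(
        chroma_producer_impl="chromadb.logservice.logservice.LogService",
        chroma_consumer_impl="chromadb.logservice.logservice.LogService",
        chroma_logservice_host="localhost",
        chroma_logservice_port=50052,
        allow_reset=True,
    )
    system = System(settings)
    system.start()  # LogService.start() opens the gRPC channel
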
diff --git a/chromadb/logservice/logservice.py b/chromadb/logservice/logservice.py
new file mode 100644
index 00000000000..6b95c469500
--- /dev/null
+++ b/chromadb/logservice/logservice.py
@@ -0,0 +1,171 @@
+import sys
+
+import grpc
+
+from chromadb.ingest import (
+    Producer,
+    Consumer,
+    ConsumerCallbackFn,
+)
+from chromadb.proto.chroma_pb2 import (
+    SubmitEmbeddingRecord as ProtoSubmitEmbeddingRecord,
+)
+from chromadb.proto.convert import to_proto_submit
+from chromadb.proto.logservice_pb2 import PushLogsRequest, PullLogsRequest
+from chromadb.proto.logservice_pb2_grpc import LogServiceStub
+from chromadb.types import (
+    SubmitEmbeddingRecord,
+    SeqId,
+)
+from chromadb.config import System
+from chromadb.telemetry.opentelemetry import (
+    OpenTelemetryClient,
+    OpenTelemetryGranularity,
+    trace_method,
+)
+from overrides import override
+from typing import Sequence, Optional, Dict, cast
+from uuid import UUID
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LogService(Producer, Consumer):
+    """
+    Distributed Chroma Log Service
+    """
+
+    _log_service_stub: LogServiceStub
+    _channel: grpc.Channel
+    _log_service_url: str
+    _log_service_port: int
+
+    def __init__(self, system: System):
+        self._log_service_url = system.settings.require("chroma_logservice_host")
+        self._log_service_port = system.settings.require("chroma_logservice_port")
+        self._opentelemetry_client = system.require(OpenTelemetryClient)
+        super().__init__(system)
+
+    @trace_method("LogService.start", OpenTelemetryGranularity.ALL)
+    @override
+    def start(self) -> None:
+        self._channel = grpc.insecure_channel(
+            f"{self._log_service_url}:{self._log_service_port}"
+        )
+        self._log_service_stub = LogServiceStub(self._channel)  # type: ignore
+        super().start()
+
+    @trace_method("LogService.stop", OpenTelemetryGranularity.ALL)
+    @override
+    def stop(self) -> None:
+        self._channel.close()
+        super().stop()
+
+    @trace_method("LogService.reset_state", OpenTelemetryGranularity.ALL)
+    @override
+    def reset_state(self) -> None:
+        super().reset_state()
+
+    @override
+    def create_topic(self, topic_name: str) -> None:
+        raise NotImplementedError("Not implemented")
+
+    @trace_method("LogService.delete_topic", OpenTelemetryGranularity.ALL)
+    @override
+    def delete_topic(self, topic_name: str) -> None:
+        raise NotImplementedError("Not implemented")
+
+    @trace_method("LogService.submit_embedding", OpenTelemetryGranularity.ALL)
+    @override
+    def submit_embedding(
+        self, topic_name: str, embedding: SubmitEmbeddingRecord
+    ) -> SeqId:
+        if not self._running:
+            raise RuntimeError("Component not running")
+
+        return self.submit_embeddings(topic_name, [embedding])[0]  # type: ignore
+
+    @trace_method("LogService.submit_embeddings", OpenTelemetryGranularity.ALL)
+    @override
+    def submit_embeddings(
+        self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
+    ) -> Sequence[SeqId]:
+        logger.info(f"Submitting {len(embeddings)} embeddings to {topic_name}")
+
+        if not self._running:
+            raise RuntimeError("Component not running")
+
+        if len(embeddings) == 0:
+            return []
+
+        # push records to the log service
+        collection_id_to_embeddings: Dict[UUID, list[SubmitEmbeddingRecord]] = {}
+        for embedding in embeddings:
+            collection_id = cast(UUID, embedding.get("collection_id"))
+            if collection_id is None:
+                raise ValueError("collection_id is required")
+            if collection_id not in collection_id_to_embeddings:
+                collection_id_to_embeddings[collection_id] = []
+            collection_id_to_embeddings[collection_id].append(embedding)
+
+        counts = []
+        for collection_id, records in collection_id_to_embeddings.items():
+            protos_to_submit = [to_proto_submit(record) for record in records]
+            counts.append(
+                self.push_logs(
+                    collection_id,
+                    cast(Sequence[SubmitEmbeddingRecord], protos_to_submit),
+                )
+            )
+
+        return counts
+
+    @trace_method("LogService.subscribe", OpenTelemetryGranularity.ALL)
+    @override
+    def subscribe(
+        self,
+        topic_name: str,
+        consume_fn: ConsumerCallbackFn,
+        start: Optional[SeqId] = None,
+        end: Optional[SeqId] = None,
+        id: Optional[UUID] = None,
+    ) -> UUID:
+        logger.info(f"Subscribing to {topic_name}, noop for logservice")
+        return UUID(int=0)
+
+    @trace_method("LogService.unsubscribe", OpenTelemetryGranularity.ALL)
+    @override
+    def unsubscribe(self, subscription_id: UUID) -> None:
+        logger.info(f"Unsubscribing from {subscription_id}, noop for logservice")
+
+    @override
+    def min_seqid(self) -> SeqId:
+        return 0
+
+    @override
+    def max_seqid(self) -> SeqId:
+        return sys.maxsize
+
+    @property
+    @override
+    def max_batch_size(self) -> int:
+        return sys.maxsize
+
+    def push_logs(
+        self, collection_id: UUID, records: Sequence[SubmitEmbeddingRecord]
+    ) -> int:
+        request = PushLogsRequest(collection_id=str(collection_id), records=records)
+        response = self._log_service_stub.PushLogs(request)
+        return response.record_count  # type: ignore
+
+    def pull_logs(
+        self, collection_id: UUID, start_id: int, batch_size: int
+    ) -> Sequence[ProtoSubmitEmbeddingRecord]:
+        request = PullLogsRequest(
+            collection_id=str(collection_id),
+            start_from_id=start_id,
+            batch_size=batch_size,
+        )
+        response = self._log_service_stub.PullLogs(request)
+        return response.records  # type: ignore
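
For orientation, a rough usage sketch of the two helpers added outside the Producer/Consumer interface: pull_logs returns proto SubmitEmbeddingRecord messages, whose id, operation, vector and metadata fields are exactly what the new tests later in this diff assert on. The collection UUID here is a placeholder; in practice it is the UUID of an existing collection:

    import uuid

    from chromadb.config import Settings, System
    from chromadb.logservice.logservice import LogService

    # Sketch only: assumes the log service is reachable via the
    # localhost:50052 -> 50051 port-forward set up above.
    system = System(Settings(allow_reset=True))
    log_service = system.instance(LogService)
    system.start()

    collection_id = uuid.uuid4()  # placeholder; use a real collection's UUID
    # Read back everything logged for the collection, starting from record 1.
    for record in log_service.pull_logs(collection_id, start_id=1, batch_size=100):
        print(record.id, record.operation)
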
diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py
index 4e55ffc6749..8a8cd979072 100644
--- a/chromadb/test/conftest.py
+++ b/chromadb/test/conftest.py
@@ -287,6 +287,7 @@ def basic_http_client() -> Generator[System, None, None]:
     settings = Settings(
         chroma_api_impl="chromadb.api.fastapi.FastAPI",
         chroma_server_http_port=8000,
+        chroma_server_host="localhost",
         allow_reset=True,
     )
     system = System(settings)
@@ -468,6 +469,7 @@ def system_wrong_auth(
 def system(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
     yield next(request.param())

+
 @pytest.fixture(scope="module", params=system_fixtures_ssl())
 def system_ssl(request: pytest.FixtureRequest) -> Generator[ServerAPI, None, None]:
     yield next(request.param())
"float_value": 1.001}, + {"int_value": 2, "float_value": 2.002, "string_value": "two"}, + ], +} + + +def verify_records( + logservice: LogService, + collection: Collection, + test_records_map: Dict[str, Dict[str, Any]], + test_func: Callable, # type: ignore + operation: int, +) -> None: + start_id = 1 + for batch_records in test_records_map.values(): + test_func(**batch_records) + pushed_records = logservice.pull_logs(collection.id, start_id, 100) + assert len(pushed_records) == len(batch_records["ids"]) + for i, record in enumerate(pushed_records): + assert record.id == batch_records["ids"][i] + assert record.operation == operation + embedding = array.array("f", batch_records["embeddings"][i]).tobytes() + assert record.vector.vector == embedding + metadata_count = 0 + if "metadatas" in batch_records: + metadata_count += len(batch_records["metadatas"][i]) + for key, value in batch_records["metadatas"][i].items(): + if isinstance(value, int): + assert record.metadata.metadata[key].int_value == value + elif isinstance(value, float): + assert record.metadata.metadata[key].float_value == value + elif isinstance(value, str): + assert record.metadata.metadata[key].string_value == value + else: + assert False + if "documents" in batch_records: + metadata_count += 1 + assert ( + record.metadata.metadata["chroma:document"].string_value + == batch_records["documents"][i] + ) + assert len(record.metadata.metadata) == metadata_count + start_id += len(pushed_records) + + +@skip_if_not_cluster() +def test_add(api): # type: ignore + system = System(Settings(allow_reset=True)) + logservice = system.instance(LogService) + system.start() + api.reset() + + test_records_map = { + "batch_records": batch_records, + "metadata_records": metadata_records, + "contains_records": contains_records, + } + + collection = api.create_collection("testadd") + verify_records(logservice, collection, test_records_map, collection.add, 0) + + +@skip_if_not_cluster() +def test_update(api): # type: ignore + system = System(Settings(allow_reset=True)) + logservice = system.instance(LogService) + system.start() + api.reset() + + test_records_map = { + "updated_records": { + "ids": [records["ids"][0]], + "embeddings": [[0.1, 0.2, 0.3]], + "metadatas": [{"foo": "bar"}], + }, + } + + collection = api.create_collection("testupdate") + verify_records(logservice, collection, test_records_map, collection.update, 1) + + +@skip_if_not_cluster() +def test_delete(api): # type: ignore + system = System(Settings(allow_reset=True)) + logservice = system.instance(LogService) + system.start() + api.reset() + + collection = api.create_collection("testdelete") + + # push 2 records + collection.add(**contains_records) + pushed_records = logservice.pull_logs(collection.id, 1, 100) + assert len(pushed_records) == 2 + + # delete by where does not work atm + collection.delete(where_document={"$contains": "doc1"}) + collection.delete(where_document={"$contains": "bad"}) + collection.delete(where_document={"$contains": "great"}) + pushed_records = logservice.pull_logs(collection.id, 3, 100) + assert len(pushed_records) == 0 + + # delete by ids + collection.delete(ids=["id1", "id2"]) + pushed_records = logservice.pull_logs(collection.id, 3, 100) + assert len(pushed_records) == 2 + for record in pushed_records: + assert record.operation == 3 + assert record.id in ["id1", "id2"] + + +@skip_if_not_cluster() +def test_upsert(api): # type: ignore + system = System(Settings(allow_reset=True)) + logservice = system.instance(LogService) + system.start() + 
diff --git a/go/internal/coordinator/grpc/collection_service.go b/go/internal/coordinator/grpc/collection_service.go
index f3d7ce94f28..2e78b0772f0 100644
--- a/go/internal/coordinator/grpc/collection_service.go
+++ b/go/internal/coordinator/grpc/collection_service.go
@@ -17,6 +17,7 @@ const successCode = 200
 const success = "ok"

 func (s *Server) ResetState(context.Context, *emptypb.Empty) (*coordinatorpb.ResetStateResponse, error) {
+	log.Info("reset state")
 	res := &coordinatorpb.ResetStateResponse{}
 	err := s.coordinator.ResetState(context.Background())
 	if err != nil {
diff --git a/go/internal/coordinator/grpc/server.go b/go/internal/coordinator/grpc/server.go
index 578298719a7..8f3c6c57624 100644
--- a/go/internal/coordinator/grpc/server.go
+++ b/go/internal/coordinator/grpc/server.go
@@ -94,11 +94,11 @@ func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider, db *gor
 		assignmentPolicy = coordinator.NewSimpleAssignmentPolicy(config.PulsarTenant, config.PulsarNamespace)
 	} else if config.AssignmentPolicy == "rendezvous" {
 		log.Info("Using rendezvous assignment policy")
-		err := utils.CreateTopics(config.PulsarAdminURL, config.PulsarTenant, config.PulsarNamespace, coordinator.Topics[:])
-		if err != nil {
-			log.Error("Failed to create topics", zap.Error(err))
-			return nil, err
-		}
+		//err := utils.CreateTopics(config.PulsarAdminURL, config.PulsarTenant, config.PulsarNamespace, coordinator.Topics[:])
+		//if err != nil {
+		//	log.Error("Failed to create topics", zap.Error(err))
+		//	return nil, err
+		//}
 		assignmentPolicy = coordinator.NewRendezvousAssignmentPolicy(config.PulsarTenant, config.PulsarNamespace)
 	} else {
 		return nil, errors.New("invalid assignment policy, only simple and rendezvous are supported")
diff --git a/k8s/distributed-chroma/templates/coordinator.yaml b/k8s/distributed-chroma/templates/coordinator.yaml
index 880ed130a4e..9b9dbf9c44e 100644
--- a/k8s/distributed-chroma/templates/coordinator.yaml
+++ b/k8s/distributed-chroma/templates/coordinator.yaml
@@ -25,7 +25,7 @@ spec:
           imagePullPolicy: IfNotPresent
           name: coordinator
           ports:
-            - containerPort: 50001
+            - containerPort: 50051
               name: grpc

 ---
@@ -38,7 +38,7 @@ metadata:
 spec:
   ports:
     - name: grpc
-      port: 50001
+      port: 50051
       targetPort: grpc
   selector:
     app: coordinator
@@ -68,4 +68,4 @@ subjects:
     name: coordinator-serviceaccount
     namespace: {{ .Values.namespace }}

----
\ No newline at end of file
+---
"coordinator.chroma" + - name: CHROMA_MEMBERLIST_PROVIDER_IMPL + value: "chromadb.segment.impl.distributed.segment_directory.MockMemberlistProvider" + - name: CHROMA_LOGSERVICE_HOST + value: "logservice.chroma" + - name: CHROMA_LOGSERVICE_PORT + value: "50051" volumes: - name: chroma emptyDir: {} diff --git a/k8s/distributed-chroma/templates/logservice.yaml b/k8s/distributed-chroma/templates/logservice.yaml index 113b0813e37..c72748b69a0 100644 --- a/k8s/distributed-chroma/templates/logservice.yaml +++ b/k8s/distributed-chroma/templates/logservice.yaml @@ -13,6 +13,7 @@ spec: labels: app: logservice spec: + serviceAccountName: logservice-serviceaccount containers: - command: - "logservice" @@ -37,3 +38,13 @@ spec: selector: app: logservice type: ClusterIP + +--- + +apiVersion: v1 +kind: ServiceAccount +metadata: + name: logservice-serviceaccount + namespace: {{ .Values.namespace }} + +--- \ No newline at end of file diff --git a/k8s/test/logservice_service.yaml b/k8s/test/logservice_service.yaml new file mode 100644 index 00000000000..2d3a4a8566a --- /dev/null +++ b/k8s/test/logservice_service.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Service +metadata: + name: logservice-lb + namespace: chroma +spec: + ports: + - name: grpc + port: 50051 + targetPort: 50051 + selector: + app: logservice + type: LoadBalancer