retryableCodes) {
- if (transactionId != null) {
- // Transactional queries are retried via the transaction runner.
- return false;
- }
-
- if (!isRetryableError(throwable, retryableCodes)) {
- return false;
- }
-
- if (rpcContext.getTotalRequestTimeout().isZero()) {
- return true;
- }
-
- Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
- return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
- }
-
/**
* Returns a query that counts the documents in the result set of this query.
*
@@ -2128,6 +1826,171 @@ public AggregateQuery aggregate(
return new AggregateQuery(this, aggregateFieldList);
}
+ /**
+ * Returns a VectorQuery that can perform vector distance (similarity) search with given
+ * parameters.
+ *
+ * The returned query, when executed, performs a distance (similarity) search on the specified
+ * `vectorField` against the given `queryVector` and returns the top documents that are closest to
+ * the `queryVector`.
+ *
+ *
Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as
+ * `queryVector` participate in the query, all other documents are ignored.
+ *
+ *
{@code VectorQuery vectorQuery = col.findNearest("embedding", new double[] {41, 42}, 10,
+ * VectorQuery.DistanceMeasure.EUCLIDEAN); QuerySnapshot querySnapshot = await
+ * vectorQuery.get().get(); DocumentSnapshot mostSimilarDocument =
+ * querySnapshot.getDocuments().get(0);}
+ *
+ * @param vectorField A string specifying the vector field to search on.
+ * @param queryVector A representation of the vector used to measure the distance from
+ * `vectorField` values in the documents.
+ * @param limit The upper bound of documents to return, must be a positive integer with a maximum
+ * value of 1000.
+ * @param distanceMeasure What type of distance is calculated when performing the query. See
+ * {@link VectorQuery.DistanceMeasure}.
+ * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given
+ * parameters.
+ */
+ public VectorQuery findNearest(
+ String vectorField,
+ double[] queryVector,
+ int limit,
+ VectorQuery.DistanceMeasure distanceMeasure) {
+ return findNearest(
+ FieldPath.fromDotSeparatedString(vectorField),
+ FieldValue.vector(queryVector),
+ limit,
+ distanceMeasure,
+ VectorQueryOptions.getDefaultInstance());
+ }
+
+ /**
+ * Returns a VectorQuery that can perform vector distance (similarity) search with given
+ * parameters.
+ *
+ *
The returned query, when executed, performs a distance (similarity) search on the specified
+ * `vectorField` against the given `queryVector` and returns the top documents that are closest to
+ * the `queryVector`.
+ *
+ *
Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as
+ * `queryVector` participate in the query, all other documents are ignored.
+ *
+ *
{@code VectorQuery vectorQuery = col.findNearest( "embedding", new double[] {41, 42}, 10,
+ * VectorQuery.DistanceMeasure.EUCLIDEAN,
+ * FindNearestOptions.newBuilder().setDistanceThreshold(0.11).setDistanceResultField("foo").build());
+ * QuerySnapshot querySnapshot = await vectorQuery.get().get(); DocumentSnapshot
+ * mostSimilarDocument = querySnapshot.getDocuments().get(0);}
+ *
+ * @param vectorField A string specifying the vector field to search on.
+ * @param queryVector A representation of the vector used to measure the distance from
+ * `vectorField` values in the documents.
+ * @param limit The upper bound of documents to return, must be a positive integer with a maximum
+ * value of 1000.
+ * @param distanceMeasure What type of distance is calculated when performing the query. See
+ * {@link VectorQuery.DistanceMeasure}.
+ * @param vectorQueryOptions Optional arguments for VectorQueries, see {@link VectorQueryOptions}.
+ * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given
+ * parameters.
+ */
+ public VectorQuery findNearest(
+ String vectorField,
+ double[] queryVector,
+ int limit,
+ VectorQuery.DistanceMeasure distanceMeasure,
+ VectorQueryOptions vectorQueryOptions) {
+ return findNearest(
+ FieldPath.fromDotSeparatedString(vectorField),
+ FieldValue.vector(queryVector),
+ limit,
+ distanceMeasure,
+ vectorQueryOptions);
+ }
+
+ /**
+ * Returns a VectorQuery that can perform vector distance (similarity) search with given
+ * parameters.
+ *
+ *
The returned query, when executed, performs a distance (similarity) search on the specified
+ * `vectorField` against the given `queryVector` and returns the top documents that are closest to
+ * the `queryVector`.
+ *
+ *
Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as
+ * `queryVector` participate in the query, all other documents are ignored.
+ *
+ *
{@code VectorValue queryVector = FieldValue.vector(new double[] {41, 42}); VectorQuery
+ * vectorQuery = col.findNearest( FieldPath.of("embedding"), queryVector, 10,
+ * VectorQuery.DistanceMeasure.EUCLIDEAN); QuerySnapshot querySnapshot = await
+ * vectorQuery.get().get(); DocumentSnapshot mostSimilarDocument =
+ * querySnapshot.getDocuments().get(0);}
+ *
+ * @param vectorField A {@link FieldPath} specifying the vector field to search on.
+ * @param queryVector The {@link VectorValue} used to measure the distance from `vectorField`
+ * values in the documents.
+ * @param limit The upper bound of documents to return, must be a positive integer with a maximum
+ * value of 1000.
+ * @param distanceMeasure What type of distance is calculated when performing the query. See
+ * {@link VectorQuery.DistanceMeasure}.
+ * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given
+ * parameters.
+ */
+ public VectorQuery findNearest(
+ FieldPath vectorField,
+ VectorValue queryVector,
+ int limit,
+ VectorQuery.DistanceMeasure distanceMeasure) {
+ return findNearest(
+ vectorField, queryVector, limit, distanceMeasure, VectorQueryOptions.getDefaultInstance());
+ }
+
+ /**
+ * Returns a VectorQuery that can perform vector distance (similarity) search with given
+ * parameters.
+ *
+ *
The returned query, when executed, performs a distance (similarity) search on the specified
+ * `vectorField` against the given `queryVector` and returns the top documents that are closest to
+ * the `queryVector`.
+ *
+ *
Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as
+ * `queryVector` participate in the query, all other documents are ignored.
+ *
+ *
{@code VectorQuery vectorQuery = col.findNearest( FieldPath.of("embedding"), queryVector,
+ * 10, VectorQuery.DistanceMeasure.EUCLIDEAN,
+ * FindNearestOptions.newBuilder().setDistanceThreshold(0.11).setDistanceResultField("foo").build());
+ * QuerySnapshot querySnapshot = await vectorQuery.get().get(); DocumentSnapshot
+ * mostSimilarDocument = querySnapshot.getDocuments().get(0);}
+ *
+ * @param vectorField A {@link FieldPath} specifying the vector field to search on.
+ * @param queryVector The {@link VectorValue} used to measure the distance from `vectorField`
+ * values in the documents.
+ * @param limit The upper bound of documents to return, must be a positive integer with a maximum
+ * value of 1000.
+ * @param distanceMeasure What type of distance is calculated when performing the query. See
+ * {@link VectorQuery.DistanceMeasure}.
+ * @param vectorQueryOptions Optional arguments for VectorQueries, see {@link VectorQueryOptions}.
+ * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given
+ * parameters.
+ */
+ public VectorQuery findNearest(
+ FieldPath vectorField,
+ VectorValue queryVector,
+ int limit,
+ VectorQuery.DistanceMeasure distanceMeasure,
+ VectorQueryOptions vectorQueryOptions) {
+ if (limit <= 0) {
+ throw FirestoreException.forInvalidArgument(
+ "Not a valid positive `limit` number. `limit` must be larger than 0.");
+ }
+
+ if (queryVector.size() == 0) {
+ throw FirestoreException.forInvalidArgument(
+ "Not a valid vector size. `queryVector` size must be larger than 0.");
+ }
+
+ return new VectorQuery(
+ this, vectorField, queryVector, limit, distanceMeasure, vectorQueryOptions);
+ }
+
/**
* Returns true if this Query is equal to the provided object.
*
diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java
index c0575586d..ab3442415 100644
--- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java
+++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java
@@ -87,4 +87,8 @@ public T toObject(@Nonnull Class valueType) {
Preconditions.checkNotNull(result, "Object in a QueryDocumentSnapshot should be non-null");
return result;
}
+
+ static int compareDocumentId(QueryDocumentSnapshot doc1, QueryDocumentSnapshot doc2) {
+ return doc1.getReference().getResourcePath().compareTo(doc2.getReference().getResourcePath());
+ }
}
diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java
index 494a298e4..1b1c0cbf9 100644
--- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java
+++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java
@@ -17,85 +17,29 @@
package com.google.cloud.firestore;
import com.google.cloud.Timestamp;
-import com.google.cloud.firestore.DocumentChange.Type;
-import com.google.cloud.firestore.encoding.CustomClassMapper;
-import com.google.common.collect.ImmutableList;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Iterator;
import java.util.List;
-import java.util.Objects;
-import javax.annotation.Nonnull;
/**
* A QuerySnapshot contains the results of a query. It can contain zero or more DocumentSnapshot
* objects.
*/
-public abstract class QuerySnapshot implements Iterable {
-
- private final Query query;
- private final Timestamp readTime;
+public class QuerySnapshot extends GenericQuerySnapshot {
+ protected QuerySnapshot(Query query, Timestamp readTime) {
+ super(query, readTime, null, null);
+ }
- protected QuerySnapshot(Query query, Timestamp readTime) { // Elevated access level for mocking.
- this.query = query;
- this.readTime = readTime;
+ protected QuerySnapshot(
+ Query query,
+ Timestamp readTime,
+ final List documents,
+ final List documentChanges) {
+ super(query, readTime, documents, documentChanges);
}
/** Creates a new QuerySnapshot representing the results of a Query with added documents. */
public static QuerySnapshot withDocuments(
final Query query, Timestamp readTime, final List documents) {
- return new QuerySnapshot(query, readTime) {
- volatile ImmutableList documentChanges;
-
- @Nonnull
- @Override
- public List getDocuments() {
- return Collections.unmodifiableList(documents);
- }
-
- @Nonnull
- @Override
- public List getDocumentChanges() {
- if (documentChanges == null) {
- synchronized (documents) {
- if (documentChanges == null) {
- int size = documents.size();
- ImmutableList.Builder builder =
- ImmutableList.builderWithExpectedSize(size);
- for (int i = 0; i < size; ++i) {
- builder.add(new DocumentChange(documents.get(i), Type.ADDED, -1, i));
- }
- documentChanges = builder.build();
- }
- }
- }
- return documentChanges;
- }
-
- @Override
- public int size() {
- return documents.size();
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- QuerySnapshot that = (QuerySnapshot) o;
- return Objects.equals(query, that.query)
- && Objects.equals(this.size(), that.size())
- && Objects.equals(this.getDocuments(), that.getDocuments());
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(query, this.getDocuments());
- }
- };
+ return new QuerySnapshot(query, readTime, documents, null);
}
/** Creates a new QuerySnapshot representing a snapshot of a Query with changed documents. */
@@ -104,134 +48,6 @@ public static QuerySnapshot withChanges(
Timestamp readTime,
final DocumentSet documentSet,
final List documentChanges) {
- return new QuerySnapshot(query, readTime) {
- volatile List documents;
-
- @Nonnull
- @Override
- public List getDocuments() {
- if (documents == null) {
- synchronized (documentSet) {
- if (documents == null) {
- documents = documentSet.toList();
- }
- }
- }
- return Collections.unmodifiableList(documents);
- }
-
- @Nonnull
- @Override
- public List getDocumentChanges() {
- return Collections.unmodifiableList(documentChanges);
- }
-
- @Override
- public int size() {
- return documentSet.size();
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- QuerySnapshot that = (QuerySnapshot) o;
- return Objects.equals(query, that.query)
- && Objects.equals(this.size(), that.size())
- && Objects.equals(this.getDocumentChanges(), that.getDocumentChanges())
- && Objects.equals(this.getDocuments(), that.getDocuments());
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(query, this.getDocumentChanges(), this.getDocuments());
- }
- };
- }
-
- /**
- * Returns the query for the snapshot.
- *
- * @return The backing query that produced this snapshot.
- */
- @Nonnull
- public Query getQuery() {
- return query;
- }
-
- /**
- * Returns the time at which this snapshot was read.
- *
- * @return The read time of this snapshot.
- */
- @Nonnull
- public Timestamp getReadTime() {
- return readTime;
- }
-
- /**
- * Returns the documents in this QuerySnapshot as a List in order of the query.
- *
- * @return The list of documents.
- */
- @Nonnull
- public abstract List getDocuments();
-
- /**
- * Returns the list of documents that changed since the last snapshot. If it's the first snapshot
- * all documents will be in the list as added changes.
- *
- * @return The list of documents that changed since the last snapshot.
- */
- @Nonnull
- public abstract List getDocumentChanges();
-
- /** Returns true if there are no documents in the QuerySnapshot. */
- public boolean isEmpty() {
- return this.size() == 0;
+ return new QuerySnapshot(query, readTime, documentSet.toList(), documentChanges);
}
-
- /** Returns the number of documents in the QuerySnapshot. */
- public abstract int size();
-
- @Override
- @Nonnull
- public Iterator iterator() {
- return getDocuments().iterator();
- }
-
- /**
- * Returns the contents of the documents in the QuerySnapshot, converted to the provided class, as
- * a list.
- *
- * @param clazz The POJO type used to convert the documents in the list.
- */
- @Nonnull
- public List toObjects(@Nonnull Class clazz) {
- List documents = getDocuments();
- List results = new ArrayList<>(documents.size());
- for (DocumentSnapshot documentSnapshot : documents) {
- results.add(
- CustomClassMapper.convertToCustomClass(
- documentSnapshot.getData(), clazz, documentSnapshot.getReference()));
- }
-
- return results;
- }
-
- /**
- * Returns true if the document data in this QuerySnapshot equals the provided snapshot.
- *
- * @param obj The object to compare against.
- * @return Whether this QuerySnapshot is equal to the provided object.
- */
- @Override
- public abstract boolean equals(Object obj);
-
- @Override
- public abstract int hashCode();
}
diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/StreamableQuery.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/StreamableQuery.java
new file mode 100644
index 000000000..24be9e69c
--- /dev/null
+++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/StreamableQuery.java
@@ -0,0 +1,400 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.firestore;
+
+import static com.google.cloud.firestore.telemetry.TraceUtil.*;
+import static com.google.common.collect.Lists.reverse;
+
+import com.google.api.core.ApiFuture;
+import com.google.api.core.SettableApiFuture;
+import com.google.api.gax.rpc.ApiStreamObserver;
+import com.google.api.gax.rpc.ResponseObserver;
+import com.google.api.gax.rpc.StatusCode;
+import com.google.api.gax.rpc.StreamController;
+import com.google.cloud.Timestamp;
+import com.google.cloud.firestore.telemetry.TraceUtil;
+import com.google.cloud.firestore.v1.FirestoreSettings;
+import com.google.common.collect.ImmutableMap;
+import com.google.firestore.v1.Document;
+import com.google.firestore.v1.RunQueryRequest;
+import com.google.firestore.v1.RunQueryResponse;
+import com.google.protobuf.ByteString;
+import io.grpc.Status;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import org.threeten.bp.Duration;
+
+/**
+ * Represents a query whose results can be streamed. If the stream fails with a retryable error,
+ * implementations of StreamableQuery can optionally support retries with a cursor, as indicated by
+ * `isRetryableWithCursor`. Retrying with a cursor means that the StreamableQuery can be resumed
+ * where it failed by first calling `startAfter(lastDocumentReceived)`.
+ */
+public abstract class StreamableQuery {
+ final Query.QueryOptions options;
+ final FirestoreRpcContext> rpcContext;
+
+ StreamableQuery(FirestoreRpcContext> rpcContext, Query.QueryOptions options) {
+ this.rpcContext = rpcContext;
+ this.options = options;
+ }
+
+ abstract RunQueryRequest.Builder toRunQueryRequestBuilder(
+ @Nullable final ByteString transactionId,
+ @Nullable final Timestamp readTime,
+ @Nullable ExplainOptions explainOptions);
+
+ abstract boolean isRetryableWithCursor();
+
+ abstract StreamableQuery startAfter(@Nonnull DocumentSnapshot snapshot);
+
+ abstract SnapshotType createSnaphot(
+ Timestamp readTime, final List documents);
+
+ /**
+ * Gets the Firestore instance associated with this query.
+ *
+ * @return The Firestore instance associated with this query.
+ */
+ @Nonnull
+ public Firestore getFirestore() {
+ return rpcContext.getFirestore();
+ }
+ /**
+ * Executes the query and returns the results as QuerySnapshot.
+ *
+ * @return An ApiFuture that will be resolved with the results of the Query.
+ */
+ @Nonnull
+ public abstract ApiFuture get();
+
+ /**
+ * Executes the query and returns the results as QuerySnapshot.
+ *
+ * @return An ApiFuture that will be resolved with the results of the Query.
+ */
+ ApiFuture get(
+ @Nullable ByteString transactionId, @Nullable Timestamp requestReadTime) {
+ TraceUtil.Span span =
+ getFirestore()
+ .getOptions()
+ .getTraceUtil()
+ .startSpan(
+ transactionId == null
+ ? TraceUtil.SPAN_NAME_QUERY_GET
+ : TraceUtil.SPAN_NAME_TRANSACTION_GET_QUERY);
+ try (Scope ignored = span.makeCurrent()) {
+ final SettableApiFuture result = SettableApiFuture.create();
+ internalStream(
+ new ApiStreamObserver() {
+ final List documentSnapshots = new ArrayList<>();
+ Timestamp responseReadTime;
+
+ @Override
+ public void onNext(RunQueryResponse runQueryResponse) {
+ if (runQueryResponse.hasDocument()) {
+ Document document = runQueryResponse.getDocument();
+ QueryDocumentSnapshot documentSnapshot =
+ QueryDocumentSnapshot.fromDocument(
+ rpcContext, Timestamp.fromProto(runQueryResponse.getReadTime()), document);
+ documentSnapshots.add(documentSnapshot);
+ }
+ if (responseReadTime == null) {
+ responseReadTime = Timestamp.fromProto(runQueryResponse.getReadTime());
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ result.setException(throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ // The results for limitToLast queries need to be flipped since we reversed the
+ // ordering constraints before sending the query to the backend.
+ List resultView =
+ Query.LimitType.Last.equals(options.getLimitType())
+ ? reverse(documentSnapshots)
+ : documentSnapshots;
+ SnapshotType querySnapshot = createSnaphot(responseReadTime, resultView);
+ result.set(querySnapshot);
+ }
+ },
+ /* startTimeNanos= */ rpcContext.getClock().nanoTime(),
+ transactionId,
+ /* readTime= */ requestReadTime,
+ /* explainOptions= */ null,
+ /* isRetryRequestWithCursor= */ false);
+
+ span.endAtFuture(result);
+ return result;
+ } catch (Exception error) {
+ span.end(error);
+ throw error;
+ }
+ }
+
+ /**
+ * Plans and optionally executes this query. Returns an ApiFuture that will be resolved with the
+ * planner information, statistics from the query execution (if any), and the query results (if
+ * any).
+ *
+ * @return An ApiFuture that will be resolved with the planner information, statistics from the
+ * query execution (if any), and the query results (if any).
+ */
+ @Nonnull
+ public ApiFuture> explain(ExplainOptions options) {
+ TraceUtil.Span span =
+ getFirestore().getOptions().getTraceUtil().startSpan(TraceUtil.SPAN_NAME_QUERY_GET);
+
+ try (Scope ignored = span.makeCurrent()) {
+ final SettableApiFuture> result = SettableApiFuture.create();
+ internalStream(
+ new ApiStreamObserver() {
+ @Nullable List documentSnapshots = null;
+ Timestamp readTime;
+ ExplainMetrics metrics;
+
+ @Override
+ public void onNext(RunQueryResponse runQueryResponse) {
+ if (runQueryResponse.hasDocument()) {
+ if (documentSnapshots == null) {
+ documentSnapshots = new ArrayList<>();
+ }
+
+ Document document = runQueryResponse.getDocument();
+ QueryDocumentSnapshot documentSnapshot =
+ QueryDocumentSnapshot.fromDocument(
+ rpcContext, Timestamp.fromProto(runQueryResponse.getReadTime()), document);
+ documentSnapshots.add(documentSnapshot);
+ }
+
+ if (readTime == null) {
+ readTime = Timestamp.fromProto(runQueryResponse.getReadTime());
+ }
+
+ if (runQueryResponse.hasExplainMetrics()) {
+ metrics = new ExplainMetrics(runQueryResponse.getExplainMetrics());
+ if (documentSnapshots == null && metrics.getExecutionStats() != null) {
+ // This indicates that the query was executed, but no documents
+ // had matched the query. Create an empty list.
+ documentSnapshots = Collections.emptyList();
+ }
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ result.setException(throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ @Nullable SnapshotType snapshot = null;
+ if (documentSnapshots != null) {
+ // The results for limitToLast queries need to be flipped since we reversed the
+ // ordering constraints before sending the query to the backend.
+ List resultView =
+ Query.LimitType.Last.equals(StreamableQuery.this.options.getLimitType())
+ ? reverse(documentSnapshots)
+ : documentSnapshots;
+ snapshot = createSnaphot(readTime, resultView);
+ }
+ result.set(new ExplainResults<>(metrics, snapshot));
+ }
+ },
+ /* startTimeNanos= */ rpcContext.getClock().nanoTime(),
+ /* transactionId= */ null,
+ /* readTime= */ null,
+ /* explainOptions= */ options,
+ /* isRetryRequestWithCursor= */ false);
+
+ span.endAtFuture(result);
+ return result;
+ } catch (Exception error) {
+ span.end(error);
+ throw error;
+ }
+ }
+
+ protected void internalStream(
+ final ApiStreamObserver runQueryResponseObserver,
+ final long startTimeNanos,
+ @Nullable final ByteString transactionId,
+ @Nullable final Timestamp readTime,
+ @Nullable final ExplainOptions explainOptions,
+ final boolean isRetryRequestWithCursor) {
+ TraceUtil traceUtil = getFirestore().getOptions().getTraceUtil();
+ // To reduce the size of traces, we only register one event for every 100 responses
+ // that we receive from the server.
+ final int NUM_RESPONSES_PER_TRACE_EVENT = 100;
+
+ TraceUtil.Span currentSpan = traceUtil.currentSpan();
+ currentSpan.addEvent(
+ TraceUtil.SPAN_NAME_RUN_QUERY,
+ new ImmutableMap.Builder()
+ .put(ATTRIBUTE_KEY_IS_TRANSACTIONAL, transactionId != null)
+ .put(ATTRIBUTE_KEY_IS_RETRY_WITH_CURSOR, isRetryRequestWithCursor)
+ .build());
+
+ final AtomicReference lastReceivedDocument = new AtomicReference<>();
+
+ ResponseObserver observer =
+ new ResponseObserver() {
+ Timestamp readTime;
+ boolean firstResponse = false;
+ int numDocuments = 0;
+
+ // The stream's `onComplete()` could be called more than once,
+ // this flag makes sure only the first one is actually processed.
+ boolean hasCompleted = false;
+
+ @Override
+ public void onStart(StreamController streamController) {}
+
+ @Override
+ public void onResponse(RunQueryResponse response) {
+ if (!firstResponse) {
+ firstResponse = true;
+ currentSpan.addEvent(TraceUtil.SPAN_NAME_RUN_QUERY + ": First Response");
+ }
+
+ runQueryResponseObserver.onNext(response);
+
+ if (response.hasDocument()) {
+ numDocuments++;
+ if (numDocuments % NUM_RESPONSES_PER_TRACE_EVENT == 0) {
+ currentSpan.addEvent(
+ TraceUtil.SPAN_NAME_RUN_QUERY + ": Received " + numDocuments + " documents");
+ }
+ Document document = response.getDocument();
+ QueryDocumentSnapshot documentSnapshot =
+ QueryDocumentSnapshot.fromDocument(
+ rpcContext, Timestamp.fromProto(response.getReadTime()), document);
+ lastReceivedDocument.set(documentSnapshot);
+ }
+
+ if (response.getDone()) {
+ currentSpan.addEvent(
+ TraceUtil.SPAN_NAME_RUN_QUERY + ": Received RunQueryResponse.Done");
+ onComplete();
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ QueryDocumentSnapshot cursor = lastReceivedDocument.get();
+ if (isRetryableWithCursor() && shouldRetry(cursor, throwable)) {
+ currentSpan.addEvent(
+ TraceUtil.SPAN_NAME_RUN_QUERY + ": Retryable Error",
+ Collections.singletonMap("error.message", throwable.getMessage()));
+
+ startAfter(cursor)
+ .internalStream(
+ runQueryResponseObserver,
+ startTimeNanos,
+ /* transactionId= */ null,
+ options.getRequireConsistency() ? cursor.getReadTime() : null,
+ explainOptions,
+ /* isRetryRequestWithCursor= */ true);
+ } else {
+ currentSpan.addEvent(
+ TraceUtil.SPAN_NAME_RUN_QUERY + ": Error",
+ Collections.singletonMap("error.message", throwable.getMessage()));
+ runQueryResponseObserver.onError(throwable);
+ }
+ }
+
+ @Override
+ public void onComplete() {
+ if (hasCompleted) return;
+ hasCompleted = true;
+ currentSpan.addEvent(
+ TraceUtil.SPAN_NAME_RUN_QUERY + ": Completed",
+ Collections.singletonMap(ATTRIBUTE_KEY_DOC_COUNT, numDocuments));
+ runQueryResponseObserver.onCompleted();
+ }
+
+ boolean shouldRetry(DocumentSnapshot lastDocument, Throwable t) {
+ if (lastDocument == null) {
+ // Only retry if we have received a single result. Retries for RPCs with initial
+ // failure are handled by Google Gax, which also implements backoff.
+ return false;
+ }
+
+ // Do not retry EXPLAIN requests because it'd be executing
+ // multiple queries. This means stats would have to be aggregated,
+ // and that may not even make sense for many statistics.
+ if (explainOptions != null) {
+ return false;
+ }
+
+ Set retryableCodes =
+ FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
+ return shouldRetryQuery(t, transactionId, startTimeNanos, retryableCodes);
+ }
+ };
+
+ rpcContext.streamRequest(
+ toRunQueryRequestBuilder(transactionId, readTime, explainOptions).build(),
+ observer,
+ rpcContext.getClient().runQueryCallable());
+ }
+
+ /** Returns whether a query that failed in the given scenario should be retried. */
+ boolean shouldRetryQuery(
+ Throwable throwable,
+ @Nullable ByteString transactionId,
+ long startTimeNanos,
+ Set retryableCodes) {
+ if (transactionId != null) {
+ // Transactional queries are retried via the transaction runner.
+ return false;
+ }
+
+ if (!isRetryableError(throwable, retryableCodes)) {
+ return false;
+ }
+
+ if (rpcContext.getTotalRequestTimeout().isZero()) {
+ return true;
+ }
+
+ Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
+ return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
+ }
+
+ /** Verifies whether the given exception is retryable based on the RunQuery configuration. */
+ private boolean isRetryableError(Throwable throwable, Set retryableCodes) {
+ if (!(throwable instanceof FirestoreException)) {
+ return false;
+ }
+ Status status = ((FirestoreException) throwable).getStatus();
+ for (StatusCode.Code code : retryableCodes) {
+ if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) {
+ return true;
+ }
+ }
+ return false;
+ }
+}
diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuery.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuery.java
new file mode 100644
index 000000000..2f8c2ac42
--- /dev/null
+++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuery.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.firestore;
+
+import com.google.api.core.ApiFuture;
+import com.google.cloud.Timestamp;
+import com.google.firestore.v1.RunQueryRequest;
+import com.google.firestore.v1.StructuredQuery;
+import com.google.protobuf.ByteString;
+import java.util.List;
+import java.util.Objects;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+/**
+ * A query that finds the documents whose vector fields are closest to a certain query vector.
+ * Create an instance of `VectorQuery` with {@link Query#findNearest}.
+ */
+public final class VectorQuery extends StreamableQuery {
+ final Query query;
+ final FieldPath vectorField;
+ final VectorValue queryVector;
+ final int limit;
+ final DistanceMeasure distanceMeasure;
+ final VectorQueryOptions options;
+
+ /** Creates a VectorQuery */
+ VectorQuery(
+ Query query,
+ FieldPath vectorField,
+ VectorValue queryVector,
+ int limit,
+ DistanceMeasure distanceMeasure,
+ VectorQueryOptions options) {
+ super(query.rpcContext, query.options);
+
+ this.query = query;
+ this.options = options;
+ this.vectorField = vectorField;
+ this.queryVector = queryVector;
+ this.limit = limit;
+ this.distanceMeasure = distanceMeasure;
+ }
+
+ /**
+ * Executes the query and returns the results as {@link QuerySnapshot}.
+ *
+ * @return An ApiFuture that will be resolved with the results of the VectorQuery.
+ */
+ @Override
+ @Nonnull
+ public ApiFuture get() {
+ return get(null, null);
+ }
+
+ /**
+ * Plans and optionally executes this VectorQuery. Returns an ApiFuture that will be resolved with
+ * the planner information, statistics from the query execution (if any), and the query results
+ * (if any).
+ *
+ * @return An ApiFuture that will be resolved with the planner information, statistics from the
+ * query execution (if any), and the query results (if any).
+ */
+ @Override
+ @Nonnull
+ public ApiFuture> explain(ExplainOptions options) {
+ return super.explain(options);
+ }
+
+ /**
+ * Returns true if this VectorQuery is equal to the provided object.
+ *
+ * @param obj The object to compare against.
+ * @return Whether this VectorQuery is equal to the provided object.
+ */
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null || !(obj instanceof VectorQuery)) {
+ return false;
+ }
+ VectorQuery otherQuery = (VectorQuery) obj;
+ return Objects.equals(query, otherQuery.query)
+ && Objects.equals(vectorField, otherQuery.vectorField)
+ && Objects.equals(queryVector, otherQuery.queryVector)
+ && Objects.equals(options, otherQuery.options)
+ && limit == otherQuery.limit
+ && distanceMeasure == otherQuery.distanceMeasure;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(query, options, vectorField, queryVector, limit, distanceMeasure);
+ }
+
+ @Override
+ protected RunQueryRequest.Builder toRunQueryRequestBuilder(
+ @Nullable final ByteString transactionId,
+ @Nullable final Timestamp readTime,
+ @Nullable ExplainOptions explainOptions) {
+
+ // Builder for the base query
+ RunQueryRequest.Builder requestBuilder =
+ query.toRunQueryRequestBuilder(transactionId, readTime, explainOptions);
+
+ // Builder for find nearest
+ StructuredQuery.FindNearest.Builder findNearestBuilder =
+ requestBuilder.getStructuredQueryBuilder().getFindNearestBuilder();
+ findNearestBuilder.getQueryVectorBuilder().setMapValue(this.queryVector.toProto());
+ findNearestBuilder.getLimitBuilder().setValue(this.limit);
+ findNearestBuilder.setDistanceMeasure(toProto(this.distanceMeasure));
+ findNearestBuilder.getVectorFieldBuilder().setFieldPath(this.vectorField.toString());
+
+ if (this.options != null) {
+ if (this.options.getDistanceThreshold() != null) {
+ findNearestBuilder
+ .getDistanceThresholdBuilder()
+ .setValue(this.options.getDistanceThreshold().doubleValue());
+ }
+ if (this.options.getDistanceResultField() != null) {
+ findNearestBuilder.setDistanceResultField(this.options.getDistanceResultField().toString());
+ }
+ }
+
+ return requestBuilder;
+ }
+
+ private static StructuredQuery.FindNearest.DistanceMeasure toProto(
+ DistanceMeasure distanceMeasure) {
+ switch (distanceMeasure) {
+ case COSINE:
+ return StructuredQuery.FindNearest.DistanceMeasure.COSINE;
+ case EUCLIDEAN:
+ return StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN;
+ case DOT_PRODUCT:
+ return StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT;
+ default:
+ return StructuredQuery.FindNearest.DistanceMeasure.UNRECOGNIZED;
+ }
+ }
+
+ @Override
+ boolean isRetryableWithCursor() {
+ return false;
+ }
+
+ @Override
+ VectorQuery startAfter(@Nonnull DocumentSnapshot snapshot) {
+ throw new RuntimeException("Not implemented");
+ }
+
+ @Override
+ VectorQuerySnapshot createSnaphot(
+ Timestamp readTime, final List documents) {
+ return VectorQuerySnapshot.withDocuments(this, readTime, documents);
+ }
+
+ /**
+ * The distance measure to use when comparing vectors in a {@link VectorQuery}.
+ *
+ * @see com.google.cloud.firestore.Query#findNearest
+ */
+ public enum DistanceMeasure {
+ /**
+ * COSINE distance compares vectors based on the angle between them, which allows you to measure
+ * similarity that isn't based on the vectors' magnitude. We recommend using DOT_PRODUCT with
+ * unit normalized vectors instead of COSINE distance, which is mathematically equivalent with
+ * better performance.
+ */
+ COSINE,
+ /** Measures the EUCLIDEAN distance between the vectors. */
+ EUCLIDEAN,
+ /** Similar to cosine but is affected by the magnitude of the vectors. */
+ DOT_PRODUCT
+ }
+}
diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQueryOptions.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQueryOptions.java
new file mode 100644
index 000000000..38ca68d23
--- /dev/null
+++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQueryOptions.java
@@ -0,0 +1,161 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.firestore;
+
+import java.util.Objects;
+import javax.annotation.Nullable;
+
+/**
+ * Specifies the behavior of the {@link VectorQuery} generated by a call to {@link
+ * Query#findNearest}.
+ */
+public class VectorQueryOptions {
+ private final @Nullable FieldPath distanceResultField;
+
+ private final @Nullable Double distanceThreshold;
+
+ @Nullable
+ public FieldPath getDistanceResultField() {
+ return distanceResultField;
+ }
+
+ @Nullable
+ public Double getDistanceThreshold() {
+ return distanceThreshold;
+ }
+
+ VectorQueryOptions(VectorQueryOptions.Builder builder) {
+ this.distanceThreshold = builder.distanceThreshold;
+ this.distanceResultField = builder.distanceResultField;
+ }
+
+ public static VectorQueryOptions.Builder newBuilder() {
+ return new VectorQueryOptions.Builder();
+ }
+
+ public static final class Builder {
+ /**
+ * Returns the name of the field that will be set on each returned DocumentSnapshot, which will
+ * contain the computed distance for the document. If `null`, then the computed distance will
+ * not be returned. Default value: `null`.
+ *
+ * Set this value with {@link VectorQueryOptions.Builder#setDistanceResultField(FieldPath)}
+ * or {@link VectorQueryOptions.Builder#setDistanceResultField(String)}.
+ */
+ private @Nullable FieldPath distanceResultField;
+
+ /**
+ * Specifies a threshold for which no less similar documents will be returned. If `null`, then
+ * the computed distance will not be returned. Default value: `null`.
+ *
+ *
Set this value with {@link VectorQueryOptions.Builder#setDistanceThreshold(Double)}.
+ */
+ private @Nullable Double distanceThreshold;
+
+ private Builder() {
+ distanceThreshold = null;
+ distanceResultField = null;
+ }
+
+ private Builder(VectorQueryOptions options) {
+ this.distanceThreshold = options.distanceThreshold;
+ this.distanceResultField = options.distanceResultField;
+ }
+
+ /**
+ * Specifies the name of the field that will be set on each returned DocumentSnapshot, which
+ * will contain the computed distance for the document. If `null`, then the computed distance
+ * will not be returned. Default value: `null`.
+ *
+ * @param fieldPath A string value specifying the distance result field.
+ */
+ public Builder setDistanceResultField(@Nullable String fieldPath) {
+ this.distanceResultField = FieldPath.fromDotSeparatedString(fieldPath);
+ return this;
+ }
+
+ /**
+ * Specifies the name of the field that will be set on each returned DocumentSnapshot, which
+ * will contain the computed distance for the document. If `null`, then the computed distance
+ * will not be returned. Default value: `null`.
+ *
+ * @param fieldPath A {@link FieldPath} value specifying the distance result field.
+ */
+ public Builder setDistanceResultField(@Nullable FieldPath fieldPath) {
+ this.distanceResultField = fieldPath;
+ return this;
+ }
+
+ /**
+ * Specifies a threshold for which no less similar documents will be returned. The behavior of
+ * the specified `distanceMeasure` will affect the meaning of the distance threshold.
+ *
+ *
+ *
+ *
+ * - For `distanceMeasure: "EUCLIDEAN"`, the meaning of `distanceThreshold` is: {@code
+ * SELECT docs WHERE euclidean_distance <= distanceThreshold}
+ *
- For `distanceMeasure: "COSINE"`, the meaning of `distanceThreshold` is: {@code SELECT
+ * docs WHERE cosine_distance <= distanceThreshold}
+ *
- For `distanceMeasure: "DOT_PRODUCT"`, the meaning of `distanceThreshold` is: {@code
+ * SELECT docs WHERE dot_product_distance >= distanceThreshold}
+ *
+ *
+ * If `null`, then the computed distance will not be returned. Default value: `null`.
+ *
+ * @param distanceThreshold A Double value specifying the distance threshold.
+ */
+ public Builder setDistanceThreshold(@Nullable Double distanceThreshold) {
+ this.distanceThreshold = distanceThreshold;
+ return this;
+ }
+
+ public VectorQueryOptions build() {
+ return new VectorQueryOptions(this);
+ }
+ }
+
+ /**
+ * Returns true if this VectorQueryOptions is equal to the provided object.
+ *
+ * @param obj The object to compare against.
+ * @return Whether this VectorQueryOptions is equal to the provided object.
+ */
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null || !(obj instanceof VectorQueryOptions)) {
+ return false;
+ }
+ VectorQueryOptions otherOptions = (VectorQueryOptions) obj;
+ return Objects.equals(distanceResultField, otherOptions.distanceResultField)
+ && Objects.equals(distanceThreshold, otherOptions.distanceThreshold);
+ }
+
+ /** Default VectorQueryOptions instance. */
+ private static VectorQueryOptions DEFAULT = newBuilder().build();
+
+ /**
+ * Returns a default {@code FirestoreOptions} instance. Note: package private until API review can
+ * be completed.
+ */
+ static VectorQueryOptions getDefaultInstance() {
+ return DEFAULT;
+ }
+}
diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuerySnapshot.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuerySnapshot.java
new file mode 100644
index 000000000..528512fad
--- /dev/null
+++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuerySnapshot.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.firestore;
+
+import com.google.cloud.Timestamp;
+import java.util.List;
+
+/**
+ * A VectorQuerySnapshot contains the results of a VectorQuery. It can contain zero or more
+ * DocumentSnapshot objects.
+ */
+public class VectorQuerySnapshot extends GenericQuerySnapshot {
+ protected VectorQuerySnapshot(
+ VectorQuery query,
+ Timestamp readTime,
+ final List documents,
+ final List documentChanges) {
+ super(query, readTime, documents, documentChanges);
+ }
+
+ /**
+ * Creates a new VectorQuerySnapshot representing the results of a VectorQuery with added
+ * documents.
+ */
+ public static VectorQuerySnapshot withDocuments(
+ final VectorQuery query, Timestamp readTime, final List documents) {
+ return new VectorQuerySnapshot(query, readTime, documents, null);
+ }
+}
diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java
index b95e8d08a..347e721f1 100644
--- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java
+++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java
@@ -69,4 +69,11 @@ public int hashCode() {
MapValue toProto() {
return UserDataConverter.encodeVector(this.values);
}
+
+ /**
+ * Returns the number of dimensions of the vector. Note: package private until API review is done.
+ */
+ int size() {
+ return this.values.length;
+ }
}
diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java
index 70d497b75..5ccd9f163 100644
--- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java
+++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java
@@ -63,10 +63,7 @@
import com.google.firestore.v1.Write;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.Empty;
-import com.google.protobuf.Message;
-import com.google.protobuf.NullValue;
+import com.google.protobuf.*;
import com.google.type.LatLng;
import java.lang.reflect.Type;
import java.math.BigInteger;
@@ -649,6 +646,39 @@ public static StructuredQuery filter(StructuredQuery.FieldFilter.Operator operat
return filter(operator, "foo", "bar");
}
+ public static StructuredQuery findNearest(
+ String fieldPath,
+ double[] queryVector,
+ int limit,
+ StructuredQuery.FindNearest.DistanceMeasure measure) {
+ ArrayValue.Builder vectorArrayBuilder = ArrayValue.newBuilder();
+ for (double d : queryVector) {
+ vectorArrayBuilder.addValues(Value.newBuilder().setDoubleValue(d));
+ }
+
+ StructuredQuery.FindNearest.Builder findNearest =
+ StructuredQuery.FindNearest.newBuilder()
+ .setVectorField(StructuredQuery.FieldReference.newBuilder().setFieldPath(fieldPath))
+ .setQueryVector(
+ Value.newBuilder()
+ .setMapValue(
+ MapValue.newBuilder()
+ .putFields(
+ "__type__", Value.newBuilder().setStringValue("__vector__").build())
+ .putFields(
+ "value",
+ Value.newBuilder()
+ .setArrayValue(vectorArrayBuilder.build())
+ .build())))
+ .setLimit(Int32Value.newBuilder().setValue(limit))
+ .setDistanceMeasure(measure);
+
+ StructuredQuery.Builder structuredQuery = StructuredQuery.newBuilder();
+ structuredQuery.setFindNearest(findNearest.build());
+
+ return structuredQuery.build();
+ }
+
public static StructuredQuery filter(
StructuredQuery.FieldFilter.Operator operator, String path, String value) {
return filter(operator, path, string(value));
diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/VectorQueryTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/VectorQueryTest.java
new file mode 100644
index 000000000..717cf6e66
--- /dev/null
+++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/VectorQueryTest.java
@@ -0,0 +1,482 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.firestore;
+
+import static com.google.cloud.firestore.LocalFirestoreHelper.COLLECTION_ID;
+import static com.google.cloud.firestore.LocalFirestoreHelper.DOCUMENT_NAME;
+import static com.google.cloud.firestore.LocalFirestoreHelper.findNearest;
+import static com.google.cloud.firestore.LocalFirestoreHelper.query;
+import static com.google.cloud.firestore.LocalFirestoreHelper.queryResponse;
+import static com.google.cloud.firestore.LocalFirestoreHelper.queryResponseWithDone;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+
+import com.google.api.gax.rpc.ResponseObserver;
+import com.google.cloud.firestore.spi.v1.FirestoreRpc;
+import com.google.firestore.v1.RunQueryRequest;
+import com.google.firestore.v1.RunQueryResponse;
+import com.google.firestore.v1.StructuredQuery;
+import io.grpc.Status;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mockito;
+import org.mockito.Spy;
+import org.mockito.junit.MockitoJUnitRunner;
+
+@RunWith(MockitoJUnitRunner.class)
+public class VectorQueryTest {
+
+ @Spy
+ private final FirestoreImpl firestoreMock =
+ new FirestoreImpl(
+ FirestoreOptions.newBuilder().setProjectId("test-project").build(),
+ Mockito.mock(FirestoreRpc.class));
+
+ @Captor private ArgumentCaptor runQuery;
+
+ @Captor private ArgumentCaptor> streamObserverCapture;
+
+ private Query queryA;
+ private Query queryB;
+
+ private QueryTest.MockClock clock;
+
+ @Before
+ public void before() {
+ clock = new QueryTest.MockClock();
+ doReturn(clock).when(firestoreMock).getClock();
+
+ queryA = firestoreMock.collection(COLLECTION_ID);
+ queryB = firestoreMock.collection(COLLECTION_ID).whereEqualTo("foo", "bar");
+ }
+
+ @Test
+ public void isEqual() {
+ Query queryA = firestoreMock.collection("collectionId").whereEqualTo("foo", "bar");
+ Query queryB = firestoreMock.collection("collectionId").whereEqualTo("foo", "bar");
+ Query queryC = firestoreMock.collection("collectionId");
+
+ assertEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE),
+ queryB.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE));
+
+ assertEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN),
+ queryB.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN));
+
+ assertEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(0.125).build()),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(0.125).build()));
+
+ assertEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder()
+ .setDistanceThreshold(0.125)
+ .setDistanceResultField(FieldPath.of("foo"))
+ .build()),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder()
+ .setDistanceThreshold(0.125)
+ .setDistanceResultField(FieldPath.of("foo"))
+ .build()));
+
+ assertEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder()
+ .setDistanceResultField(FieldPath.of("distance"))
+ .build()));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE),
+ queryC.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE),
+ queryB.findNearest(
+ "embedding", new double[] {40, 42}, 10, VectorQuery.DistanceMeasure.COSINE));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE),
+ queryB.findNearest(
+ "embedding", new double[] {40, 41, 42}, 1000, VectorQuery.DistanceMeasure.COSINE));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE),
+ queryB.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(0.125).build()),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(1.125).build()));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build()));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build()),
+ queryB.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceResultField("result").build()));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder()
+ .setDistanceResultField(FieldPath.of("distance"))
+ .build()),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder()
+ .setDistanceResultField(FieldPath.of("result"))
+ .build()));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN),
+ queryB.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceResultField("result").build()));
+
+ assertNotEquals(
+ queryA.findNearest(
+ "embedding",
+ new double[] {40, 41, 42},
+ 10,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()),
+ queryB.findNearest(
+ "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN));
+ }
+
+ @Test
+ public void validatesInputsLimit() {
+ String expectedExceptionMessage = ".*Not a valid positive `limit` number.*";
+ Throwable exception =
+ assertThrows(
+ RuntimeException.class,
+ () ->
+ queryA.findNearest(
+ "embedding", new double[] {10, 100}, 0, VectorQuery.DistanceMeasure.EUCLIDEAN));
+ assertTrue(exception.getMessage().matches(expectedExceptionMessage));
+ }
+
+ @Test
+ public void validatesInputsVectorSize() {
+ String expectedExceptionMessage = ".*Not a valid vector size.*";
+ Throwable exception =
+ assertThrows(
+ RuntimeException.class,
+ () ->
+ queryA.findNearest(
+ "embedding", new double[0], 10, VectorQuery.DistanceMeasure.EUCLIDEAN));
+ assertTrue(exception.getMessage().matches(expectedExceptionMessage));
+ }
+
+ @Test
+ public void successfulReturnWithoutOnComplete() throws Exception {
+ doAnswer(
+ queryResponseWithDone(
+ /* callWithoutOnComplete */ true, DOCUMENT_NAME + "1", DOCUMENT_NAME + "2"))
+ .when(firestoreMock)
+ .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any());
+
+ VectorQuerySnapshot snapshot =
+ queryA
+ .findNearest("vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.COSINE)
+ .get()
+ .get();
+
+ assertEquals(2, snapshot.size());
+ assertEquals(false, snapshot.isEmpty());
+ assertTrue((DOCUMENT_NAME + "1").endsWith(snapshot.getDocuments().get(0).getId()));
+ assertTrue((DOCUMENT_NAME + "2").endsWith(snapshot.getDocuments().get(1).getId()));
+ assertEquals(2, snapshot.getDocumentChanges().size());
+ assertEquals(1, snapshot.getReadTime().getSeconds());
+ assertEquals(2, snapshot.getReadTime().getNanos());
+
+ assertEquals(1, runQuery.getAllValues().size());
+ assertEquals(
+ query(
+ findNearest(
+ "vector",
+ new double[] {1, -9.5},
+ 10,
+ StructuredQuery.FindNearest.DistanceMeasure.COSINE)),
+ runQuery.getValue());
+ }
+
+ @Test
+ public void successfulReturn() throws Exception {
+ doAnswer(queryResponse(DOCUMENT_NAME + "1", DOCUMENT_NAME + "2"))
+ .when(firestoreMock)
+ .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any());
+
+ VectorQuerySnapshot snapshot =
+ queryA
+ .findNearest(
+ "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT)
+ .get()
+ .get();
+
+ assertEquals(2, snapshot.size());
+ assertEquals(false, snapshot.isEmpty());
+ assertTrue((DOCUMENT_NAME + "1").endsWith(snapshot.getDocuments().get(0).getId()));
+ assertTrue((DOCUMENT_NAME + "2").endsWith(snapshot.getDocuments().get(1).getId()));
+ assertEquals(2, snapshot.getDocumentChanges().size());
+ assertEquals(1, snapshot.getReadTime().getSeconds());
+ assertEquals(2, snapshot.getReadTime().getNanos());
+
+ assertEquals(1, runQuery.getAllValues().size());
+ assertEquals(
+ query(
+ findNearest(
+ "vector",
+ new double[] {1, -9.5},
+ 10,
+ StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT)),
+ runQuery.getValue());
+ }
+
+ @Test
+ public void handlesStreamExceptionRetryableError() {
+ final boolean[] returnError = new boolean[] {true};
+
+ doAnswer(
+ invocation -> {
+ if (returnError[0]) {
+ returnError[0] = false;
+ return queryResponse(
+ FirestoreException.forServerRejection(
+ Status.DEADLINE_EXCEEDED, "Simulated test failure"),
+ DOCUMENT_NAME + "1",
+ DOCUMENT_NAME + "2")
+ .answer(invocation);
+ } else {
+ return queryResponse(DOCUMENT_NAME + "3").answer(invocation);
+ }
+ })
+ .when(firestoreMock)
+ .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any());
+
+ ExecutionException e =
+ assertThrows(
+ ExecutionException.class,
+ () ->
+ queryA
+ .findNearest(
+ "vector",
+ new double[] {1, -9.5},
+ 10,
+ VectorQuery.DistanceMeasure.DOT_PRODUCT)
+ .get()
+ .get());
+
+ // Verify the requests
+ List requests = runQuery.getAllValues();
+ assertEquals(1, requests.size());
+
+ assertTrue(requests.get(0).getStructuredQuery().hasFindNearest());
+
+ assertEquals(e.getCause().getClass(), FirestoreException.class);
+ assertTrue(e.getCause().getMessage().matches(".*Simulated test failure.*"));
+ }
+
+ @Test
+ public void handlesStreamExceptionNonRetryableError() {
+ final boolean[] returnError = new boolean[] {true};
+
+ doAnswer(
+ invocation -> {
+ if (returnError[0]) {
+ returnError[0] = false;
+ return queryResponse(
+ FirestoreException.forServerRejection(
+ Status.PERMISSION_DENIED, "Simulated test failure"))
+ .answer(invocation);
+ } else {
+ return queryResponse(DOCUMENT_NAME + "3").answer(invocation);
+ }
+ })
+ .when(firestoreMock)
+ .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any());
+
+ ExecutionException e =
+ assertThrows(
+ ExecutionException.class,
+ () ->
+ queryA
+ .findNearest(
+ "vector",
+ new double[] {1, -9.5},
+ 10,
+ VectorQuery.DistanceMeasure.DOT_PRODUCT)
+ .get()
+ .get());
+
+ // Verify the requests
+ List requests = runQuery.getAllValues();
+ assertEquals(1, requests.size());
+
+ assertTrue(requests.get(0).getStructuredQuery().hasFindNearest());
+
+ assertEquals(e.getCause().getClass(), FirestoreException.class);
+ assertTrue(e.getCause().getMessage().matches(".*Simulated test failure.*"));
+ }
+
+ @Test
+ public void vectorQuerySnapshotEquality() throws Exception {
+ final int[] count = {0};
+
+ doAnswer(
+ invocation -> {
+ switch (count[0]++) {
+ case 0:
+ return queryResponse(DOCUMENT_NAME + "3", DOCUMENT_NAME + "4").answer(invocation);
+ case 1:
+ return queryResponse(DOCUMENT_NAME + "3", DOCUMENT_NAME + "4").answer(invocation);
+ case 2:
+ return queryResponse(DOCUMENT_NAME + "3", DOCUMENT_NAME + "4").answer(invocation);
+ case 3:
+ return queryResponse(DOCUMENT_NAME + "4").answer(invocation);
+ default:
+ return queryResponse().answer(invocation);
+ }
+ })
+ .when(firestoreMock)
+ .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any());
+
+ VectorQuerySnapshot snapshotA =
+ queryA
+ .findNearest(
+ "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT)
+ .get()
+ .get();
+ VectorQuerySnapshot snapshotB =
+ queryA
+ .findNearest(
+ "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT)
+ .get()
+ .get();
+ VectorQuerySnapshot snapshotC =
+ queryB
+ .findNearest(
+ "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT)
+ .get()
+ .get();
+ VectorQuerySnapshot snapshotD =
+ queryA
+ .findNearest(
+ "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT)
+ .get()
+ .get();
+
+ assertEquals(snapshotA, snapshotB);
+ assertNotEquals(snapshotA, snapshotC);
+ assertNotEquals(snapshotA, snapshotD);
+ }
+}
diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java
index fce1d87d6..438762e14 100644
--- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java
+++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java
@@ -21,6 +21,7 @@
import static com.google.cloud.firestore.it.TestHelper.isRunningAgainstFirestoreEmulator;
import static com.google.common.truth.Truth.assertThat;
import static java.util.Collections.singletonMap;
+import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeFalse;
import static org.junit.Assume.assumeTrue;
diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryFindNearestTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryFindNearestTest.java
new file mode 100644
index 000000000..123e9d4a9
--- /dev/null
+++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryFindNearestTest.java
@@ -0,0 +1,933 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.firestore.it;
+
+import static com.google.cloud.firestore.LocalFirestoreHelper.autoId;
+import static com.google.cloud.firestore.LocalFirestoreHelper.map;
+import static com.google.cloud.firestore.it.TestHelper.await;
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.*;
+
+import com.google.cloud.firestore.*;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestName;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class ITQueryFindNearestTest extends ITBaseTest {
+ static String testId = autoId();
+
+ @Rule public TestName testName = new TestName();
+
+ private String getUniqueTestId() {
+ return testId + "-" + testName.getMethodName();
+ }
+
+ private CollectionReference testCollection() {
+ String collectionPath = "index-test-collection";
+ return firestore.collection(collectionPath);
+ }
+
+ private CollectionReference testCollectionWithDocs(Map> docs)
+ throws InterruptedException {
+ CollectionReference collection = testCollection();
+ CollectionReference writer = firestore.collection(collection.getId());
+ writeAllDocs(writer, docs);
+ return collection;
+ }
+
+ private String getUniqueDocId(String docId) {
+ return testId + docId;
+ }
+
+ public void writeAllDocs(CollectionReference collection, Map> docs)
+ throws InterruptedException {
+ for (Map.Entry> doc : docs.entrySet()) {
+ Map data = doc.getValue();
+ data.put("testId", getUniqueTestId());
+ data.put("docId", getUniqueDocId(doc.getKey()));
+ await(collection.add(data));
+ }
+ }
+
+ @Test
+ public void findNearestWithEuclideanDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a", map("foo", "bar"),
+ "b", map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})),
+ "c", map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})),
+ "d", map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})),
+ "e", map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})),
+ "f", map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("d"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("c"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("e"));
+ }
+
+ @Test
+ public void findNearestWithEuclideanDistanceFirestoreTypeOverride() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a", map("foo", "bar"),
+ "b", map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})),
+ "c", map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})),
+ "d", map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})),
+ "e", map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})),
+ "f", map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ FieldPath.of("embedding"),
+ FieldValue.vector(new double[] {10, 10}),
+ 3,
+ VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("d"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("c"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("e"));
+ }
+
+ @Test
+ public void findNearestWithCosineDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a",
+ map("foo", "bar"),
+ "b",
+ map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})),
+ "c",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})),
+ "d",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})),
+ "e",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})),
+ "f",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest("embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.COSINE);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertTrue(
+ Arrays.asList(getUniqueDocId("f"), getUniqueDocId("c"))
+ .contains(snapshot.getDocuments().get(0).get("docId")));
+ assertTrue(
+ Arrays.asList(getUniqueDocId("f"), getUniqueDocId("c"))
+ .contains(snapshot.getDocuments().get(1).get("docId")));
+ assertTrue(
+ Arrays.asList(getUniqueDocId("d"), getUniqueDocId("e"))
+ .contains(snapshot.getDocuments().get(2).get("docId")));
+ }
+
+ @Test
+ public void findNearestWithDotProductDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a",
+ map("foo", "bar"),
+ "b",
+ map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})),
+ "c",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})),
+ "d",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})),
+ "e",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})),
+ "f",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.DOT_PRODUCT);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("f"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("e"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("d"));
+ }
+
+ @Test
+ public void findNearestSkipsFieldsOfWrongTypes() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a",
+ map("foo", "bar"),
+ "b",
+ map("foo", "bar", "embedding", Arrays.asList(10.0, 10.0)),
+ "c",
+ map("foo", "bar", "embedding", "not actually a vector"),
+ "d",
+ map("foo", "bar", "embedding", null),
+ "e",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {9, 9})),
+ "f",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {50, 50})),
+ "g",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {10, 10}, 100, VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("e"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("f"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("g"));
+ }
+
+ @Test
+ public void findNearestIgnoresMismatchingDimensions() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a",
+ map("foo", "bar"),
+ "b",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {10})),
+ "c",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {9, 9})),
+ "d",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {50, 50}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("c"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("d"));
+ }
+
+ @Test
+ public void findNearestOnNonExistentField() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a",
+ map("foo", "bar"),
+ "b",
+ map("foo", "bar", "otherField", Arrays.asList(10.0, 10.0)),
+ "c",
+ map("foo", "bar", "otherField", "not actually a vector"),
+ "d",
+ map("foo", "bar", "otherField", null)));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(0);
+ }
+
+ @Test
+ public void findNearestOnVectorNestedInAMap() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a",
+ map("nested", map("foo", "bar")),
+ "b",
+ map(
+ "nested",
+ map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10}))),
+ "c",
+ map(
+ "nested",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1}))),
+ "d",
+ map(
+ "nested",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0}))),
+ "e",
+ map(
+ "nested",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0}))),
+ "f",
+ map(
+ "nested",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100})))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "nested.embedding",
+ new double[] {10, 10},
+ 3,
+ VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("b"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("d"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("c"));
+ }
+
+ @Test
+ public void findNearestWithSelectToExcludeVectorDataInResponse() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a",
+ map("foo", 1),
+ "b",
+ map("foo", 2, "embedding", FieldValue.vector(new double[] {10, 10})),
+ "c",
+ map("foo", 3, "embedding", FieldValue.vector(new double[] {1, 1})),
+ "d",
+ map("foo", 4, "embedding", FieldValue.vector(new double[] {10, 0})),
+ "e",
+ map("foo", 5, "embedding", FieldValue.vector(new double[] {20, 0})),
+ "f",
+ map("foo", 6, "embedding", FieldValue.vector(new double[] {100, 100}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereIn("foo", Arrays.asList(1, 2, 3, 4, 5, 6))
+ .whereEqualTo("testId", getUniqueTestId())
+ .select("foo", "docId")
+ .findNearest(
+ "embedding", new double[] {10, 10}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(5);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("b"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("d"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("c"));
+ assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("e"));
+ assertThat(snapshot.getDocuments().get(4).get("docId")).isEqualTo(getUniqueDocId("f"));
+
+ for (QueryDocumentSnapshot doc : snapshot.getDocuments()) {
+ assertThat(doc.get("embedding")).isNull();
+ }
+ }
+
+ @Test
+ public void findNearestLimits() throws Exception {
+ double[] embeddingVector = new double[2048];
+ double[] queryVector = new double[2048];
+ for (int i = 0; i < 2048; i++) {
+ embeddingVector[i] = i + 1;
+ queryVector[i] = i - 1;
+ }
+
+ CollectionReference collection =
+ testCollectionWithDocs(map("a", map("embedding", FieldValue.vector(embeddingVector))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest("embedding", queryVector, 1000, VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(1);
+
+ assertThat(((VectorValue) snapshot.getDocuments().get(0).get("embedding")).toArray())
+ .isEqualTo(embeddingVector);
+ }
+
+ @Test
+ public void requestingComputedCosineDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, 1})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, -0.1})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-1, 0}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.COSINE,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(4);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(0);
+
+ assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1);
+ assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(1);
+
+ assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("5"));
+ assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(2);
+ }
+
+ @Test
+ public void requestingComputedEuclideanDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, -0.1})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {4, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(4);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("4"));
+ assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(0.1);
+
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1);
+
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("5"));
+ assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(5);
+
+ assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("3"));
+ assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(100);
+ }
+
+ @Test
+ public void requestingComputedDotProductDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.DOT_PRODUCT,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(4);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3"));
+ assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1);
+
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("5"));
+ assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(0.1);
+
+ assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("4"));
+ assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(-20);
+ }
+
+ @Test
+ public void overwritesDistanceResultFieldOnConflict() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map(
+ "foo",
+ "bar",
+ "embedding",
+ FieldValue.vector(new double[] {0, 1}),
+ "distance",
+ "100 miles")));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.COSINE,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(1);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("1"));
+ assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(1);
+ }
+
+ @Test
+ public void requestingComputedDistanceInSelectQueries() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, 1})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, -0.1})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-1, 0}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .select("embedding", "distance", "docId")
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.COSINE,
+ VectorQueryOptions.newBuilder().setDistanceResultField("distance").build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(4);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(0);
+
+ assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1);
+ assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(1);
+
+ assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("5"));
+ assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(2);
+ }
+
+ @Test
+ public void queryingWithDistanceThresholdUsingCosineDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, -0.1})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-1, 0}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.COSINE,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("4"));
+ }
+
+ @Test
+ public void queryingWithDistanceThresholdUsingEuclideanDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, -0.1})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {4, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.EUCLIDEAN,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(5.0).build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(3);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("4"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("5"));
+ }
+
+ @Test
+ public void queryingWithDistanceThresholdUsingDotProductDistance() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.DOT_PRODUCT,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3"));
+ }
+
+ @Test
+ public void queryWithDistanceResultFieldAndDistanceThreshold() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 5,
+ VectorQuery.DistanceMeasure.DOT_PRODUCT,
+ VectorQueryOptions.newBuilder()
+ .setDistanceThreshold(0.11)
+ .setDistanceResultField("foo")
+ .build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(0).getDouble("foo")).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3"));
+ assertThat(snapshot.getDocuments().get(1).getDouble("foo")).isEqualTo(1);
+ }
+
+ @Test
+ public void queryWithDistanceResultFieldAndDistanceThresholdWithFirestoreTypes()
+ throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ FieldPath.of("embedding"),
+ FieldValue.vector(new double[] {1, 0}),
+ 5,
+ VectorQuery.DistanceMeasure.DOT_PRODUCT,
+ VectorQueryOptions.newBuilder()
+ .setDistanceThreshold(0.11)
+ .setDistanceResultField("foo")
+ .build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(0).getDouble("foo")).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3"));
+ assertThat(snapshot.getDocuments().get(1).getDouble("foo")).isEqualTo(1);
+ }
+
+ @Test
+ public void willNotExceedLimitEvenIfThereAreMoreResultsMoreSimilarThanDistanceThreshold()
+ throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding",
+ new double[] {1, 0},
+ 2, // limit set to 2
+ VectorQuery.DistanceMeasure.DOT_PRODUCT,
+ VectorQueryOptions.newBuilder().setDistanceThreshold(0.0).build());
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertThat(snapshot.size()).isEqualTo(2);
+
+ assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2"));
+ assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3"));
+ }
+
+ @Test
+ public void testVectorQueryPlan() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {1, 0}, 5, VectorQuery.DistanceMeasure.DOT_PRODUCT);
+
+ ExplainResults explainResults =
+ vectorQuery.explain(ExplainOptions.builder().setAnalyze(false).build()).get();
+
+ @Nullable VectorQuerySnapshot snapshot = explainResults.getSnapshot();
+ assertThat(snapshot).isNull();
+
+ ExplainMetrics metrics = explainResults.getMetrics();
+ assertThat(metrics).isNotNull();
+
+ PlanSummary planSummary = metrics.getPlanSummary();
+ assertThat(planSummary).isNotNull();
+ assertThat(planSummary.getIndexesUsed()).isNotEmpty();
+
+ ExecutionStats stats = metrics.getExecutionStats();
+ assertThat(stats).isNull();
+ }
+
+ @Test
+ public void testVectorQueryProfile() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "1",
+ map("foo", "bar"),
+ "2",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})),
+ "3",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})),
+ "4",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})),
+ "5",
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {1, 0}, 5, VectorQuery.DistanceMeasure.DOT_PRODUCT);
+
+ ExplainResults explainResults =
+ vectorQuery.explain(ExplainOptions.builder().setAnalyze(true).build()).get();
+
+ @Nullable VectorQuerySnapshot snapshot = explainResults.getSnapshot();
+ assertThat(snapshot).isNotNull();
+ assertThat(snapshot.size()).isEqualTo(4);
+
+ ExplainMetrics metrics = explainResults.getMetrics();
+ assertThat(metrics).isNotNull();
+
+ PlanSummary planSummary = metrics.getPlanSummary();
+ assertThat(planSummary).isNotNull();
+ assertThat(planSummary.getIndexesUsed()).isNotEmpty();
+
+ ExecutionStats stats = metrics.getExecutionStats();
+ assertThat(stats).isNotNull();
+ assertThat(stats.getDebugStats()).isNotEmpty();
+ assertThat(stats.getReadOperations()).isEqualTo(5);
+ assertThat(stats.getResultsReturned()).isEqualTo(4);
+ assertThat(stats.getExecutionDuration()).isGreaterThan(Duration.ZERO);
+ }
+
+ @Test
+ public void vectorQuerySnapshotReturnsVectorQuery() throws Exception {
+ CollectionReference collection =
+ testCollectionWithDocs(
+ map(
+ "a", map("foo", "bar"),
+ "b", map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})),
+ "c", map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})),
+ "d", map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})),
+ "e", map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})),
+ "f", map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100}))));
+
+ VectorQuery vectorQuery =
+ collection
+ .whereEqualTo("foo", "bar")
+ .whereEqualTo("testId", getUniqueTestId())
+ .findNearest(
+ "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN);
+
+ VectorQuerySnapshot snapshot = vectorQuery.get().get();
+
+ assertTrue(snapshot.getQuery() == vectorQuery);
+ }
+}