diff --git a/README.md b/README.md index d7af89bdf..5ccc2f9f6 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ If you are using Maven without the BOM, add this to your dependencies: com.google.cloud google-cloud-firestore - 3.26.4 + 3.26.5 ``` diff --git a/google-cloud-firestore/clirr-ignored-differences.xml b/google-cloud-firestore/clirr-ignored-differences.xml index 3c42178c0..fd6e9e5b0 100644 --- a/google-cloud-firestore/clirr-ignored-differences.xml +++ b/google-cloud-firestore/clirr-ignored-differences.xml @@ -306,4 +306,12 @@ com/google/cloud/firestore/encoding/CustomClassMapper * + + + + 7002 + com/google/cloud/firestore/QuerySnapshot + com.google.cloud.firestore.Query getQuery() + + diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/GenericQuerySnapshot.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/GenericQuerySnapshot.java new file mode 100644 index 000000000..a7b02e22a --- /dev/null +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/GenericQuerySnapshot.java @@ -0,0 +1,165 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.firestore; + +import com.google.cloud.Timestamp; +import com.google.cloud.firestore.encoding.CustomClassMapper; +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import javax.annotation.Nonnull; + +/** + * Abstract. A GenericQuerySnapshot represents the results of a query that returns documents. It can + * contain zero or more DocumentSnapshot objects. + */ +public abstract class GenericQuerySnapshot implements Iterable { + protected final QueryT query; + protected final Timestamp readTime; + + private List documentChanges; + private final List documents; + + // Elevated access level for mocking. + protected GenericQuerySnapshot( + QueryT query, + Timestamp readTime, + @Nonnull final List documents, + final List documentChanges) { + this.query = query; + this.readTime = readTime; + this.documentChanges = + documentChanges != null ? Collections.unmodifiableList(documentChanges) : documentChanges; + this.documents = Collections.unmodifiableList(documents); + } + + /** + * Returns the query for the snapshot. + * + * @return The backing query that produced this snapshot. + */ + @Nonnull + public QueryT getQuery() { + return query; + } + + /** + * Returns the time at which this snapshot was read. + * + * @return The read time of this snapshot. + */ + @Nonnull + public Timestamp getReadTime() { + return readTime; + } + + /** + * Returns the documents in this QuerySnapshot as a List in order of the query. + * + * @return The list of documents. + */ + @Nonnull + public List getDocuments() { + return this.documents; + } + + /** Returns true if there are no documents in the QuerySnapshot. 
*/ + public boolean isEmpty() { + return this.size() == 0; + } + + @Nonnull + public Iterator iterator() { + return getDocuments().iterator(); + } + + /** + * Returns the contents of the documents in the QuerySnapshot, converted to the provided class, as + * a list. + * + * @param clazz The POJO type used to convert the documents in the list. + */ + @Nonnull + public List toObjects(@Nonnull Class clazz) { + List documents = getDocuments(); + List results = new ArrayList<>(documents.size()); + for (DocumentSnapshot documentSnapshot : documents) { + results.add( + CustomClassMapper.convertToCustomClass( + documentSnapshot.getData(), clazz, documentSnapshot.getReference())); + } + + return results; + } + + /** + * Returns the list of documents that changed since the last snapshot. If it's the first snapshot + * all documents will be in the list as added changes. + * + * @return The list of documents that changed since the last snapshot. + */ + @Nonnull + public List getDocumentChanges() { + if (documentChanges == null) { + synchronized (documents) { + if (documentChanges == null) { + int size = documents.size(); + ImmutableList.Builder builder = + ImmutableList.builderWithExpectedSize(size); + for (int i = 0; i < size; ++i) { + builder.add(new DocumentChange(documents.get(i), DocumentChange.Type.ADDED, -1, i)); + } + documentChanges = builder.build(); + } + } + } + + return documentChanges; + } + + /** Returns the number of DocumentSnapshots in this snapshot. */ + public int size() { + return getDocuments().size(); + } + + /** + * Tests for equality with this object. + * + * @param o is tested for equality with this object. 
+ * @return `true` if equal, otherwise `false` + */ + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GenericQuerySnapshot that = (GenericQuerySnapshot) o; + return Objects.equals(query, that.query) + && Objects.equals(this.getDocumentChanges(), that.getDocumentChanges()) + && Objects.equals(this.getDocuments(), that.getDocuments()); + } + + @Override + public int hashCode() { + return Objects.hash(query, this.getDocumentChanges(), this.getDocuments()); + } +} diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java index a794b6a63..e66393b84 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java @@ -16,8 +16,6 @@ package com.google.cloud.firestore; -import static com.google.cloud.firestore.telemetry.TraceUtil.*; -import static com.google.common.collect.Lists.reverse; import static com.google.firestore.v1.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS; import static com.google.firestore.v1.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY; import static com.google.firestore.v1.StructuredQuery.FieldFilter.Operator.EQUAL; @@ -33,19 +31,12 @@ import com.google.api.core.InternalExtensionOnly; import com.google.api.core.SettableApiFuture; import com.google.api.gax.rpc.ApiStreamObserver; -import com.google.api.gax.rpc.ResponseObserver; -import com.google.api.gax.rpc.StatusCode; -import com.google.api.gax.rpc.StreamController; import com.google.auto.value.AutoValue; import com.google.cloud.Timestamp; import com.google.cloud.firestore.Query.QueryOptions.Builder; import com.google.cloud.firestore.encoding.CustomClassMapper; -import com.google.cloud.firestore.telemetry.TraceUtil; -import com.google.cloud.firestore.telemetry.TraceUtil.Scope; -import 
com.google.cloud.firestore.v1.FirestoreSettings; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.firestore.bundle.BundledQuery; import com.google.firestore.v1.Cursor; import com.google.firestore.v1.Document; @@ -61,7 +52,6 @@ import com.google.firestore.v1.Value; import com.google.protobuf.ByteString; import com.google.protobuf.Int32Value; -import io.grpc.Status; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -70,27 +60,22 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; -import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import org.threeten.bp.Duration; /** * A Query which you can read or listen to. You can also construct refined Query objects by adding * filters and ordering. */ @InternalExtensionOnly -public class Query { - - static final Comparator DOCUMENT_ID_COMPARATOR = Query::compareDocumentId; - final FirestoreRpcContext rpcContext; - final QueryOptions options; +public class Query extends StreamableQuery { + static final Comparator DOCUMENT_ID_COMPARATOR = + QueryDocumentSnapshot::compareDocumentId; private static final Logger LOGGER = Logger.getLogger(Query.class.getName()); /** The direction of a sort. */ @@ -481,18 +466,12 @@ abstract static class Builder { } protected Query(FirestoreRpcContext rpcContext, QueryOptions queryOptions) { - this.rpcContext = rpcContext; - this.options = queryOptions; + super(rpcContext, queryOptions); } - /** - * Gets the Firestore instance associated with this query. - * - * @return The Firestore instance associated with this query. 
- */ - @Nonnull - public Firestore getFirestore() { - return rpcContext.getFirestore(); + @Override + QuerySnapshot createSnaphot(Timestamp readTime, final List documents) { + return QuerySnapshot.withDocuments(this, readTime, documents); } /** Checks whether the provided object is NULL or NaN. */ @@ -1232,6 +1211,11 @@ public Query select(FieldPath... fieldPaths) { return new Query(rpcContext, newOptions.build()); } + @Override + boolean isRetryableWithCursor() { + return true; + } + /** * Creates and returns a new Query that starts after the provided document (exclusive). The * starting position is relative to the order of the query. The document must contain all of the @@ -1241,6 +1225,7 @@ public Query select(FieldPath... fieldPaths) { * @return The created Query. */ @Nonnull + @Override public Query startAfter(@Nonnull DocumentSnapshot snapshot) { ImmutableList fieldOrders = createImplicitOrderBy(); Cursor cursor = createCursor(fieldOrders, snapshot, false); @@ -1601,9 +1586,30 @@ public void onCompleted() { * @return the serialized RunQueryRequest */ public RunQueryRequest toProto() { + return toRunQueryRequestBuilder(null, null, null).build(); + } + + @Override + protected RunQueryRequest.Builder toRunQueryRequestBuilder( + @Nullable final ByteString transactionId, + @Nullable final Timestamp readTime, + @Nullable ExplainOptions explainOptions) { + + // Builder for RunQueryRequest RunQueryRequest.Builder request = RunQueryRequest.newBuilder(); - request.setStructuredQuery(buildQuery()).setParent(options.getParentPath().toString()); - return request.build(); + request.setStructuredQuery(buildQuery()); + request.setParent(options.getParentPath().toString()); + if (explainOptions != null) { + request.setExplainOptions(explainOptions.toProto()); + } + if (transactionId != null) { + request.setTransaction(transactionId); + } + if (readTime != null) { + request.setReadTime(readTime.toProto()); + } + + return request; } /** @@ -1705,147 +1711,12 @@ private Value 
encodeValue(FieldPath fieldPath, Object value) { return encodedValue; } - private void internalStream( - final ApiStreamObserver runQueryResponseObserver, - final long startTimeNanos, - @Nullable final ByteString transactionId, - @Nullable final Timestamp readTime, - @Nullable final ExplainOptions explainOptions, - final boolean isRetryRequestWithCursor) { - TraceUtil traceUtil = getFirestore().getOptions().getTraceUtil(); - // To reduce the size of traces, we only register one event for every 100 responses - // that we receive from the server. - final int NUM_RESPONSES_PER_TRACE_EVENT = 100; - - RunQueryRequest.Builder request = RunQueryRequest.newBuilder(); - request.setStructuredQuery(buildQuery()).setParent(options.getParentPath().toString()); - - if (explainOptions != null) { - request.setExplainOptions(explainOptions.toProto()); - } - - if (transactionId != null) { - request.setTransaction(transactionId); - } - if (readTime != null) { - request.setReadTime(readTime.toProto()); - } - - TraceUtil.Span currentSpan = traceUtil.currentSpan(); - currentSpan.addEvent( - TraceUtil.SPAN_NAME_RUN_QUERY, - new ImmutableMap.Builder() - .put(ATTRIBUTE_KEY_IS_TRANSACTIONAL, transactionId != null) - .put(ATTRIBUTE_KEY_IS_RETRY_WITH_CURSOR, isRetryRequestWithCursor) - .build()); - - final AtomicReference lastReceivedDocument = new AtomicReference<>(); - - ResponseObserver observer = - new ResponseObserver() { - Timestamp readTime; - boolean firstResponse = false; - int numDocuments = 0; - - // The stream's `onComplete()` could be called more than once, - // this flag makes sure only the first one is actually processed. 
- boolean hasCompleted = false; - - @Override - public void onStart(StreamController streamController) {} - - @Override - public void onResponse(RunQueryResponse response) { - if (!firstResponse) { - firstResponse = true; - currentSpan.addEvent(TraceUtil.SPAN_NAME_RUN_QUERY + ": First Response"); - } - - runQueryResponseObserver.onNext(response); - - if (response.hasDocument()) { - numDocuments++; - if (numDocuments % NUM_RESPONSES_PER_TRACE_EVENT == 0) { - currentSpan.addEvent( - TraceUtil.SPAN_NAME_RUN_QUERY + ": Received " + numDocuments + " documents"); - } - Document document = response.getDocument(); - QueryDocumentSnapshot documentSnapshot = - QueryDocumentSnapshot.fromDocument( - rpcContext, Timestamp.fromProto(response.getReadTime()), document); - lastReceivedDocument.set(documentSnapshot); - } - - if (response.getDone()) { - currentSpan.addEvent( - TraceUtil.SPAN_NAME_RUN_QUERY + ": Received RunQueryResponse.Done"); - onComplete(); - } - } - - @Override - public void onError(Throwable throwable) { - QueryDocumentSnapshot cursor = lastReceivedDocument.get(); - if (shouldRetry(cursor, throwable)) { - currentSpan.addEvent( - TraceUtil.SPAN_NAME_RUN_QUERY + ": Retryable Error", - Collections.singletonMap("error.message", throwable.getMessage())); - - Query.this - .startAfter(cursor) - .internalStream( - runQueryResponseObserver, - startTimeNanos, - /* transactionId= */ null, - options.getRequireConsistency() ? 
cursor.getReadTime() : null, - explainOptions, - /* isRetryRequestWithCursor= */ true); - } else { - currentSpan.addEvent( - TraceUtil.SPAN_NAME_RUN_QUERY + ": Error", - Collections.singletonMap("error.message", throwable.getMessage())); - runQueryResponseObserver.onError(throwable); - } - } - - @Override - public void onComplete() { - if (hasCompleted) return; - hasCompleted = true; - currentSpan.addEvent( - TraceUtil.SPAN_NAME_RUN_QUERY + ": Completed", - Collections.singletonMap(ATTRIBUTE_KEY_DOC_COUNT, numDocuments)); - runQueryResponseObserver.onCompleted(); - } - - boolean shouldRetry(DocumentSnapshot lastDocument, Throwable t) { - if (lastDocument == null) { - // Only retry if we have received a single result. Retries for RPCs with initial - // failure are handled by Google Gax, which also implements backoff. - return false; - } - - // Do not retry EXPLAIN requests because it'd be executing - // multiple queries. This means stats would have to be aggregated, - // and that may not even make sense for many statistics. - if (explainOptions != null) { - return false; - } - - Set retryableCodes = - FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes(); - return shouldRetryQuery(t, transactionId, startTimeNanos, retryableCodes); - } - }; - - rpcContext.streamRequest(request.build(), observer, rpcContext.getClient().runQueryCallable()); - } - /** * Executes the query and returns the results as QuerySnapshot. * * @return An ApiFuture that will be resolved with the results of the Query. */ + @Override @Nonnull public ApiFuture get() { return get(null, null); @@ -1859,79 +1730,10 @@ public ApiFuture get() { * @return An ApiFuture that will be resolved with the planner information, statistics from the * query execution (if any), and the query results (if any). 
*/ + @Override @Nonnull public ApiFuture> explain(ExplainOptions options) { - TraceUtil.Span span = - getFirestore().getOptions().getTraceUtil().startSpan(TraceUtil.SPAN_NAME_QUERY_GET); - - try (Scope ignored = span.makeCurrent()) { - final SettableApiFuture> result = SettableApiFuture.create(); - internalStream( - new ApiStreamObserver() { - @Nullable List documentSnapshots = null; - Timestamp readTime; - ExplainMetrics metrics; - - @Override - public void onNext(RunQueryResponse runQueryResponse) { - if (runQueryResponse.hasDocument()) { - if (documentSnapshots == null) { - documentSnapshots = new ArrayList<>(); - } - - Document document = runQueryResponse.getDocument(); - QueryDocumentSnapshot documentSnapshot = - QueryDocumentSnapshot.fromDocument( - rpcContext, Timestamp.fromProto(runQueryResponse.getReadTime()), document); - documentSnapshots.add(documentSnapshot); - } - - if (readTime == null) { - readTime = Timestamp.fromProto(runQueryResponse.getReadTime()); - } - - if (runQueryResponse.hasExplainMetrics()) { - metrics = new ExplainMetrics(runQueryResponse.getExplainMetrics()); - if (documentSnapshots == null && metrics.getExecutionStats() != null) { - // This indicates that the query was executed, but no documents - // had matched the query. Create an empty list. - documentSnapshots = Collections.emptyList(); - } - } - } - - @Override - public void onError(Throwable throwable) { - result.setException(throwable); - } - - @Override - public void onCompleted() { - @Nullable QuerySnapshot snapshot = null; - if (documentSnapshots != null) { - // The results for limitToLast queries need to be flipped since we reversed the - // ordering constraints before sending the query to the backend. - List resultView = - LimitType.Last.equals(Query.this.options.getLimitType()) - ? 
reverse(documentSnapshots) - : documentSnapshots; - snapshot = QuerySnapshot.withDocuments(Query.this, readTime, resultView); - } - result.set(new ExplainResults<>(metrics, snapshot)); - } - }, - /* startTimeNanos= */ rpcContext.getClock().nanoTime(), - /* transactionId= */ null, - /* readTime= */ null, - /* explainOptions= */ options, - /* isRetryRequestWithCursor= */ false); - - span.endAtFuture(result); - return result; - } catch (Exception error) { - span.end(error); - throw error; - } + return super.explain(options); } /** @@ -1958,69 +1760,6 @@ public ListenerRegistration addSnapshotListener( return Watch.forQuery(this).runWatch(executor, listener); } - ApiFuture get( - @Nullable ByteString transactionId, @Nullable Timestamp requestReadTime) { - TraceUtil.Span span = - getFirestore() - .getOptions() - .getTraceUtil() - .startSpan( - transactionId == null - ? TraceUtil.SPAN_NAME_QUERY_GET - : TraceUtil.SPAN_NAME_TRANSACTION_GET_QUERY); - try (Scope ignored = span.makeCurrent()) { - final SettableApiFuture result = SettableApiFuture.create(); - internalStream( - new ApiStreamObserver() { - final List documentSnapshots = new ArrayList<>(); - Timestamp responseReadTime; - - @Override - public void onNext(RunQueryResponse runQueryResponse) { - if (runQueryResponse.hasDocument()) { - Document document = runQueryResponse.getDocument(); - QueryDocumentSnapshot documentSnapshot = - QueryDocumentSnapshot.fromDocument( - rpcContext, Timestamp.fromProto(runQueryResponse.getReadTime()), document); - documentSnapshots.add(documentSnapshot); - } - if (responseReadTime == null) { - responseReadTime = Timestamp.fromProto(runQueryResponse.getReadTime()); - } - } - - @Override - public void onError(Throwable throwable) { - result.setException(throwable); - } - - @Override - public void onCompleted() { - // The results for limitToLast queries need to be flipped since we reversed the - // ordering constraints before sending the query to the backend. 
- List resultView = - LimitType.Last.equals(Query.this.options.getLimitType()) - ? reverse(documentSnapshots) - : documentSnapshots; - QuerySnapshot querySnapshot = - QuerySnapshot.withDocuments(Query.this, responseReadTime, resultView); - result.set(querySnapshot); - } - }, - /* startTimeNanos= */ rpcContext.getClock().nanoTime(), - transactionId, - /* readTime= */ requestReadTime, - /* explainOptions= */ null, - /* isRetryRequestWithCursor= */ false); - - span.endAtFuture(result); - return result; - } catch (Exception error) { - span.end(error); - throw error; - } - } - Comparator comparator() { Iterator iterator = options.getFieldOrders().iterator(); if (!iterator.hasNext()) { @@ -2037,10 +1776,6 @@ Comparator comparator() { return comparator.thenComparing(lastDirection.documentIdComparator); } - private static int compareDocumentId(QueryDocumentSnapshot doc1, QueryDocumentSnapshot doc2) { - return doc1.getReference().getResourcePath().compareTo(doc2.getReference().getResourcePath()); - } - /** * Helper method to append an element to an existing ImmutableList. Returns the newly created * list. @@ -2053,43 +1788,6 @@ private ImmutableList append(ImmutableList existingList, T newElement) return builder.build(); } - /** Verifies whether the given exception is retryable based on the RunQuery configuration. */ - private boolean isRetryableError(Throwable throwable, Set retryableCodes) { - if (!(throwable instanceof FirestoreException)) { - return false; - } - Status status = ((FirestoreException) throwable).getStatus(); - for (StatusCode.Code code : retryableCodes) { - if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) { - return true; - } - } - return false; - } - - /** Returns whether a query that failed in the given scenario should be retried. 
*/ - boolean shouldRetryQuery( - Throwable throwable, - @Nullable ByteString transactionId, - long startTimeNanos, - Set retryableCodes) { - if (transactionId != null) { - // Transactional queries are retried via the transaction runner. - return false; - } - - if (!isRetryableError(throwable, retryableCodes)) { - return false; - } - - if (rpcContext.getTotalRequestTimeout().isZero()) { - return true; - } - - Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos); - return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0; - } - /** * Returns a query that counts the documents in the result set of this query. * @@ -2128,6 +1826,171 @@ public AggregateQuery aggregate( return new AggregateQuery(this, aggregateFieldList); } + /** + * Returns a VectorQuery that can perform vector distance (similarity) search with given + * parameters. + * + *

The returned query, when executed, performs a distance (similarity) search on the specified + * `vectorField` against the given `queryVector` and returns the top documents that are closest to + * the `queryVector`. + * + *

Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as + * `queryVector` participate in the query, all other documents are ignored. + * + *

{@code VectorQuery vectorQuery = col.findNearest("embedding", new double[] {41, 42}, 10, + * VectorQuery.DistanceMeasure.EUCLIDEAN); QuerySnapshot querySnapshot = await + * vectorQuery.get().get(); DocumentSnapshot mostSimilarDocument = + * querySnapshot.getDocuments().get(0);} + * + * @param vectorField A string specifying the vector field to search on. + * @param queryVector A representation of the vector used to measure the distance from + * `vectorField` values in the documents. + * @param limit The upper bound of documents to return, must be a positive integer with a maximum + * value of 1000. + * @param distanceMeasure What type of distance is calculated when performing the query. See + * {@link VectorQuery.DistanceMeasure}. + * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given + * parameters. + */ + public VectorQuery findNearest( + String vectorField, + double[] queryVector, + int limit, + VectorQuery.DistanceMeasure distanceMeasure) { + return findNearest( + FieldPath.fromDotSeparatedString(vectorField), + FieldValue.vector(queryVector), + limit, + distanceMeasure, + VectorQueryOptions.getDefaultInstance()); + } + + /** + * Returns a VectorQuery that can perform vector distance (similarity) search with given + * parameters. + * + *

The returned query, when executed, performs a distance (similarity) search on the specified + * `vectorField` against the given `queryVector` and returns the top documents that are closest to + * the `queryVector`. + * + *

Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as + * `queryVector` participate in the query, all other documents are ignored. + * + *

{@code VectorQuery vectorQuery = col.findNearest( "embedding", new double[] {41, 42}, 10, + * VectorQuery.DistanceMeasure.EUCLIDEAN, + * FindNearestOptions.newBuilder().setDistanceThreshold(0.11).setDistanceResultField("foo").build()); + * QuerySnapshot querySnapshot = await vectorQuery.get().get(); DocumentSnapshot + * mostSimilarDocument = querySnapshot.getDocuments().get(0);} + * + * @param vectorField A string specifying the vector field to search on. + * @param queryVector A representation of the vector used to measure the distance from + * `vectorField` values in the documents. + * @param limit The upper bound of documents to return, must be a positive integer with a maximum + * value of 1000. + * @param distanceMeasure What type of distance is calculated when performing the query. See + * {@link VectorQuery.DistanceMeasure}. + * @param vectorQueryOptions Optional arguments for VectorQueries, see {@link VectorQueryOptions}. + * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given + * parameters. + */ + public VectorQuery findNearest( + String vectorField, + double[] queryVector, + int limit, + VectorQuery.DistanceMeasure distanceMeasure, + VectorQueryOptions vectorQueryOptions) { + return findNearest( + FieldPath.fromDotSeparatedString(vectorField), + FieldValue.vector(queryVector), + limit, + distanceMeasure, + vectorQueryOptions); + } + + /** + * Returns a VectorQuery that can perform vector distance (similarity) search with given + * parameters. + * + *

The returned query, when executed, performs a distance (similarity) search on the specified + * `vectorField` against the given `queryVector` and returns the top documents that are closest to + * the `queryVector`. + * + *

Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as + * `queryVector` participate in the query, all other documents are ignored. + * + *

{@code VectorValue queryVector = FieldValue.vector(new double[] {41, 42}); VectorQuery + * vectorQuery = col.findNearest( FieldPath.of("embedding"), queryVector, 10, + * VectorQuery.DistanceMeasure.EUCLIDEAN); QuerySnapshot querySnapshot = await + * vectorQuery.get().get(); DocumentSnapshot mostSimilarDocument = + * querySnapshot.getDocuments().get(0);} + * + * @param vectorField A {@link FieldPath} specifying the vector field to search on. + * @param queryVector The {@link VectorValue} used to measure the distance from `vectorField` + * values in the documents. + * @param limit The upper bound of documents to return, must be a positive integer with a maximum + * value of 1000. + * @param distanceMeasure What type of distance is calculated when performing the query. See + * {@link VectorQuery.DistanceMeasure}. + * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given + * parameters. + */ + public VectorQuery findNearest( + FieldPath vectorField, + VectorValue queryVector, + int limit, + VectorQuery.DistanceMeasure distanceMeasure) { + return findNearest( + vectorField, queryVector, limit, distanceMeasure, VectorQueryOptions.getDefaultInstance()); + } + + /** + * Returns a VectorQuery that can perform vector distance (similarity) search with given + * parameters. + * + *

The returned query, when executed, performs a distance (similarity) search on the specified + * `vectorField` against the given `queryVector` and returns the top documents that are closest to + * the `queryVector`. + * + *

Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as + * `queryVector` participate in the query, all other documents are ignored. + * + *

{@code VectorQuery vectorQuery = col.findNearest( FieldPath.of("embedding"), queryVector, + * 10, VectorQuery.DistanceMeasure.EUCLIDEAN, + * FindNearestOptions.newBuilder().setDistanceThreshold(0.11).setDistanceResultField("foo").build()); + * QuerySnapshot querySnapshot = await vectorQuery.get().get(); DocumentSnapshot + * mostSimilarDocument = querySnapshot.getDocuments().get(0);} + * + * @param vectorField A {@link FieldPath} specifying the vector field to search on. + * @param queryVector The {@link VectorValue} used to measure the distance from `vectorField` + * values in the documents. + * @param limit The upper bound of documents to return, must be a positive integer with a maximum + * value of 1000. + * @param distanceMeasure What type of distance is calculated when performing the query. See + * {@link VectorQuery.DistanceMeasure}. + * @param vectorQueryOptions Optional arguments for VectorQueries, see {@link VectorQueryOptions}. + * @return an {@link VectorQuery} that performs vector distance (similarity) search with the given + * parameters. + */ + public VectorQuery findNearest( + FieldPath vectorField, + VectorValue queryVector, + int limit, + VectorQuery.DistanceMeasure distanceMeasure, + VectorQueryOptions vectorQueryOptions) { + if (limit <= 0) { + throw FirestoreException.forInvalidArgument( + "Not a valid positive `limit` number. `limit` must be larger than 0."); + } + + if (queryVector.size() == 0) { + throw FirestoreException.forInvalidArgument( + "Not a valid vector size. `queryVector` size must be larger than 0."); + } + + return new VectorQuery( + this, vectorField, queryVector, limit, distanceMeasure, vectorQueryOptions); + } + /** * Returns true if this Query is equal to the provided object. 
* diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java index c0575586d..ab3442415 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QueryDocumentSnapshot.java @@ -87,4 +87,8 @@ public T toObject(@Nonnull Class valueType) { Preconditions.checkNotNull(result, "Object in a QueryDocumentSnapshot should be non-null"); return result; } + + static int compareDocumentId(QueryDocumentSnapshot doc1, QueryDocumentSnapshot doc2) { + return doc1.getReference().getResourcePath().compareTo(doc2.getReference().getResourcePath()); + } } diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java index 494a298e4..1b1c0cbf9 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/QuerySnapshot.java @@ -17,85 +17,29 @@ package com.google.cloud.firestore; import com.google.cloud.Timestamp; -import com.google.cloud.firestore.DocumentChange.Type; -import com.google.cloud.firestore.encoding.CustomClassMapper; -import com.google.common.collect.ImmutableList; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; import java.util.List; -import java.util.Objects; -import javax.annotation.Nonnull; /** * A QuerySnapshot contains the results of a query. It can contain zero or more DocumentSnapshot * objects. 
*/ -public abstract class QuerySnapshot implements Iterable { - - private final Query query; - private final Timestamp readTime; +public class QuerySnapshot extends GenericQuerySnapshot { + protected QuerySnapshot(Query query, Timestamp readTime) { + super(query, readTime, null, null); + } - protected QuerySnapshot(Query query, Timestamp readTime) { // Elevated access level for mocking. - this.query = query; - this.readTime = readTime; + protected QuerySnapshot( + Query query, + Timestamp readTime, + final List documents, + final List documentChanges) { + super(query, readTime, documents, documentChanges); } /** Creates a new QuerySnapshot representing the results of a Query with added documents. */ public static QuerySnapshot withDocuments( final Query query, Timestamp readTime, final List documents) { - return new QuerySnapshot(query, readTime) { - volatile ImmutableList documentChanges; - - @Nonnull - @Override - public List getDocuments() { - return Collections.unmodifiableList(documents); - } - - @Nonnull - @Override - public List getDocumentChanges() { - if (documentChanges == null) { - synchronized (documents) { - if (documentChanges == null) { - int size = documents.size(); - ImmutableList.Builder builder = - ImmutableList.builderWithExpectedSize(size); - for (int i = 0; i < size; ++i) { - builder.add(new DocumentChange(documents.get(i), Type.ADDED, -1, i)); - } - documentChanges = builder.build(); - } - } - } - return documentChanges; - } - - @Override - public int size() { - return documents.size(); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - QuerySnapshot that = (QuerySnapshot) o; - return Objects.equals(query, that.query) - && Objects.equals(this.size(), that.size()) - && Objects.equals(this.getDocuments(), that.getDocuments()); - } - - @Override - public int hashCode() { - return Objects.hash(query, this.getDocuments()); - } - }; + 
return new QuerySnapshot(query, readTime, documents, null); } /** Creates a new QuerySnapshot representing a snapshot of a Query with changed documents. */ @@ -104,134 +48,6 @@ public static QuerySnapshot withChanges( Timestamp readTime, final DocumentSet documentSet, final List documentChanges) { - return new QuerySnapshot(query, readTime) { - volatile List documents; - - @Nonnull - @Override - public List getDocuments() { - if (documents == null) { - synchronized (documentSet) { - if (documents == null) { - documents = documentSet.toList(); - } - } - } - return Collections.unmodifiableList(documents); - } - - @Nonnull - @Override - public List getDocumentChanges() { - return Collections.unmodifiableList(documentChanges); - } - - @Override - public int size() { - return documentSet.size(); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - QuerySnapshot that = (QuerySnapshot) o; - return Objects.equals(query, that.query) - && Objects.equals(this.size(), that.size()) - && Objects.equals(this.getDocumentChanges(), that.getDocumentChanges()) - && Objects.equals(this.getDocuments(), that.getDocuments()); - } - - @Override - public int hashCode() { - return Objects.hash(query, this.getDocumentChanges(), this.getDocuments()); - } - }; - } - - /** - * Returns the query for the snapshot. - * - * @return The backing query that produced this snapshot. - */ - @Nonnull - public Query getQuery() { - return query; - } - - /** - * Returns the time at which this snapshot was read. - * - * @return The read time of this snapshot. - */ - @Nonnull - public Timestamp getReadTime() { - return readTime; - } - - /** - * Returns the documents in this QuerySnapshot as a List in order of the query. - * - * @return The list of documents. - */ - @Nonnull - public abstract List getDocuments(); - - /** - * Returns the list of documents that changed since the last snapshot. 
If it's the first snapshot - * all documents will be in the list as added changes. - * - * @return The list of documents that changed since the last snapshot. - */ - @Nonnull - public abstract List getDocumentChanges(); - - /** Returns true if there are no documents in the QuerySnapshot. */ - public boolean isEmpty() { - return this.size() == 0; + return new QuerySnapshot(query, readTime, documentSet.toList(), documentChanges); } - - /** Returns the number of documents in the QuerySnapshot. */ - public abstract int size(); - - @Override - @Nonnull - public Iterator iterator() { - return getDocuments().iterator(); - } - - /** - * Returns the contents of the documents in the QuerySnapshot, converted to the provided class, as - * a list. - * - * @param clazz The POJO type used to convert the documents in the list. - */ - @Nonnull - public List toObjects(@Nonnull Class clazz) { - List documents = getDocuments(); - List results = new ArrayList<>(documents.size()); - for (DocumentSnapshot documentSnapshot : documents) { - results.add( - CustomClassMapper.convertToCustomClass( - documentSnapshot.getData(), clazz, documentSnapshot.getReference())); - } - - return results; - } - - /** - * Returns true if the document data in this QuerySnapshot equals the provided snapshot. - * - * @param obj The object to compare against. - * @return Whether this QuerySnapshot is equal to the provided object. 
- */ - @Override - public abstract boolean equals(Object obj); - - @Override - public abstract int hashCode(); } diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/StreamableQuery.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/StreamableQuery.java new file mode 100644 index 000000000..24be9e69c --- /dev/null +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/StreamableQuery.java @@ -0,0 +1,400 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.firestore; + +import static com.google.cloud.firestore.telemetry.TraceUtil.*; +import static com.google.common.collect.Lists.reverse; + +import com.google.api.core.ApiFuture; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.rpc.ApiStreamObserver; +import com.google.api.gax.rpc.ResponseObserver; +import com.google.api.gax.rpc.StatusCode; +import com.google.api.gax.rpc.StreamController; +import com.google.cloud.Timestamp; +import com.google.cloud.firestore.telemetry.TraceUtil; +import com.google.cloud.firestore.v1.FirestoreSettings; +import com.google.common.collect.ImmutableMap; +import com.google.firestore.v1.Document; +import com.google.firestore.v1.RunQueryRequest; +import com.google.firestore.v1.RunQueryResponse; +import com.google.protobuf.ByteString; +import io.grpc.Status; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import org.threeten.bp.Duration; + +/** + * Represents a query whose results can be streamed. If the stream fails with a retryable error, + * implementations of StreamableQuery can optionally support retries with a cursor, as indicated by + * `isRetryableWithCursor`. Retrying with a cursor means that the StreamableQuery can be resumed + * where it failed by first calling `startAfter(lastDocumentReceived)`. 
+ */ +public abstract class StreamableQuery { + final Query.QueryOptions options; + final FirestoreRpcContext rpcContext; + + StreamableQuery(FirestoreRpcContext rpcContext, Query.QueryOptions options) { + this.rpcContext = rpcContext; + this.options = options; + } + + abstract RunQueryRequest.Builder toRunQueryRequestBuilder( + @Nullable final ByteString transactionId, + @Nullable final Timestamp readTime, + @Nullable ExplainOptions explainOptions); + + abstract boolean isRetryableWithCursor(); + + abstract StreamableQuery startAfter(@Nonnull DocumentSnapshot snapshot); + + abstract SnapshotType createSnaphot( + Timestamp readTime, final List documents); + + /** + * Gets the Firestore instance associated with this query. + * + * @return The Firestore instance associated with this query. + */ + @Nonnull + public Firestore getFirestore() { + return rpcContext.getFirestore(); + } + /** + * Executes the query and returns the results as QuerySnapshot. + * + * @return An ApiFuture that will be resolved with the results of the Query. + */ + @Nonnull + public abstract ApiFuture get(); + + /** + * Executes the query and returns the results as QuerySnapshot. + * + * @return An ApiFuture that will be resolved with the results of the Query. + */ + ApiFuture get( + @Nullable ByteString transactionId, @Nullable Timestamp requestReadTime) { + TraceUtil.Span span = + getFirestore() + .getOptions() + .getTraceUtil() + .startSpan( + transactionId == null + ? 
TraceUtil.SPAN_NAME_QUERY_GET + : TraceUtil.SPAN_NAME_TRANSACTION_GET_QUERY); + try (Scope ignored = span.makeCurrent()) { + final SettableApiFuture result = SettableApiFuture.create(); + internalStream( + new ApiStreamObserver() { + final List documentSnapshots = new ArrayList<>(); + Timestamp responseReadTime; + + @Override + public void onNext(RunQueryResponse runQueryResponse) { + if (runQueryResponse.hasDocument()) { + Document document = runQueryResponse.getDocument(); + QueryDocumentSnapshot documentSnapshot = + QueryDocumentSnapshot.fromDocument( + rpcContext, Timestamp.fromProto(runQueryResponse.getReadTime()), document); + documentSnapshots.add(documentSnapshot); + } + if (responseReadTime == null) { + responseReadTime = Timestamp.fromProto(runQueryResponse.getReadTime()); + } + } + + @Override + public void onError(Throwable throwable) { + result.setException(throwable); + } + + @Override + public void onCompleted() { + // The results for limitToLast queries need to be flipped since we reversed the + // ordering constraints before sending the query to the backend. + List resultView = + Query.LimitType.Last.equals(options.getLimitType()) + ? reverse(documentSnapshots) + : documentSnapshots; + SnapshotType querySnapshot = createSnaphot(responseReadTime, resultView); + result.set(querySnapshot); + } + }, + /* startTimeNanos= */ rpcContext.getClock().nanoTime(), + transactionId, + /* readTime= */ requestReadTime, + /* explainOptions= */ null, + /* isRetryRequestWithCursor= */ false); + + span.endAtFuture(result); + return result; + } catch (Exception error) { + span.end(error); + throw error; + } + } + + /** + * Plans and optionally executes this query. Returns an ApiFuture that will be resolved with the + * planner information, statistics from the query execution (if any), and the query results (if + * any). 
+ * + * @return An ApiFuture that will be resolved with the planner information, statistics from the + * query execution (if any), and the query results (if any). + */ + @Nonnull + public ApiFuture> explain(ExplainOptions options) { + TraceUtil.Span span = + getFirestore().getOptions().getTraceUtil().startSpan(TraceUtil.SPAN_NAME_QUERY_GET); + + try (Scope ignored = span.makeCurrent()) { + final SettableApiFuture> result = SettableApiFuture.create(); + internalStream( + new ApiStreamObserver() { + @Nullable List documentSnapshots = null; + Timestamp readTime; + ExplainMetrics metrics; + + @Override + public void onNext(RunQueryResponse runQueryResponse) { + if (runQueryResponse.hasDocument()) { + if (documentSnapshots == null) { + documentSnapshots = new ArrayList<>(); + } + + Document document = runQueryResponse.getDocument(); + QueryDocumentSnapshot documentSnapshot = + QueryDocumentSnapshot.fromDocument( + rpcContext, Timestamp.fromProto(runQueryResponse.getReadTime()), document); + documentSnapshots.add(documentSnapshot); + } + + if (readTime == null) { + readTime = Timestamp.fromProto(runQueryResponse.getReadTime()); + } + + if (runQueryResponse.hasExplainMetrics()) { + metrics = new ExplainMetrics(runQueryResponse.getExplainMetrics()); + if (documentSnapshots == null && metrics.getExecutionStats() != null) { + // This indicates that the query was executed, but no documents + // had matched the query. Create an empty list. + documentSnapshots = Collections.emptyList(); + } + } + } + + @Override + public void onError(Throwable throwable) { + result.setException(throwable); + } + + @Override + public void onCompleted() { + @Nullable SnapshotType snapshot = null; + if (documentSnapshots != null) { + // The results for limitToLast queries need to be flipped since we reversed the + // ordering constraints before sending the query to the backend. + List resultView = + Query.LimitType.Last.equals(StreamableQuery.this.options.getLimitType()) + ? 
reverse(documentSnapshots) + : documentSnapshots; + snapshot = createSnaphot(readTime, resultView); + } + result.set(new ExplainResults<>(metrics, snapshot)); + } + }, + /* startTimeNanos= */ rpcContext.getClock().nanoTime(), + /* transactionId= */ null, + /* readTime= */ null, + /* explainOptions= */ options, + /* isRetryRequestWithCursor= */ false); + + span.endAtFuture(result); + return result; + } catch (Exception error) { + span.end(error); + throw error; + } + } + + protected void internalStream( + final ApiStreamObserver runQueryResponseObserver, + final long startTimeNanos, + @Nullable final ByteString transactionId, + @Nullable final Timestamp readTime, + @Nullable final ExplainOptions explainOptions, + final boolean isRetryRequestWithCursor) { + TraceUtil traceUtil = getFirestore().getOptions().getTraceUtil(); + // To reduce the size of traces, we only register one event for every 100 responses + // that we receive from the server. + final int NUM_RESPONSES_PER_TRACE_EVENT = 100; + + TraceUtil.Span currentSpan = traceUtil.currentSpan(); + currentSpan.addEvent( + TraceUtil.SPAN_NAME_RUN_QUERY, + new ImmutableMap.Builder() + .put(ATTRIBUTE_KEY_IS_TRANSACTIONAL, transactionId != null) + .put(ATTRIBUTE_KEY_IS_RETRY_WITH_CURSOR, isRetryRequestWithCursor) + .build()); + + final AtomicReference lastReceivedDocument = new AtomicReference<>(); + + ResponseObserver observer = + new ResponseObserver() { + Timestamp readTime; + boolean firstResponse = false; + int numDocuments = 0; + + // The stream's `onComplete()` could be called more than once, + // this flag makes sure only the first one is actually processed. 
+ boolean hasCompleted = false; + + @Override + public void onStart(StreamController streamController) {} + + @Override + public void onResponse(RunQueryResponse response) { + if (!firstResponse) { + firstResponse = true; + currentSpan.addEvent(TraceUtil.SPAN_NAME_RUN_QUERY + ": First Response"); + } + + runQueryResponseObserver.onNext(response); + + if (response.hasDocument()) { + numDocuments++; + if (numDocuments % NUM_RESPONSES_PER_TRACE_EVENT == 0) { + currentSpan.addEvent( + TraceUtil.SPAN_NAME_RUN_QUERY + ": Received " + numDocuments + " documents"); + } + Document document = response.getDocument(); + QueryDocumentSnapshot documentSnapshot = + QueryDocumentSnapshot.fromDocument( + rpcContext, Timestamp.fromProto(response.getReadTime()), document); + lastReceivedDocument.set(documentSnapshot); + } + + if (response.getDone()) { + currentSpan.addEvent( + TraceUtil.SPAN_NAME_RUN_QUERY + ": Received RunQueryResponse.Done"); + onComplete(); + } + } + + @Override + public void onError(Throwable throwable) { + QueryDocumentSnapshot cursor = lastReceivedDocument.get(); + if (isRetryableWithCursor() && shouldRetry(cursor, throwable)) { + currentSpan.addEvent( + TraceUtil.SPAN_NAME_RUN_QUERY + ": Retryable Error", + Collections.singletonMap("error.message", throwable.getMessage())); + + startAfter(cursor) + .internalStream( + runQueryResponseObserver, + startTimeNanos, + /* transactionId= */ null, + options.getRequireConsistency() ? 
cursor.getReadTime() : null, + explainOptions, + /* isRetryRequestWithCursor= */ true); + } else { + currentSpan.addEvent( + TraceUtil.SPAN_NAME_RUN_QUERY + ": Error", + Collections.singletonMap("error.message", throwable.getMessage())); + runQueryResponseObserver.onError(throwable); + } + } + + @Override + public void onComplete() { + if (hasCompleted) return; + hasCompleted = true; + currentSpan.addEvent( + TraceUtil.SPAN_NAME_RUN_QUERY + ": Completed", + Collections.singletonMap(ATTRIBUTE_KEY_DOC_COUNT, numDocuments)); + runQueryResponseObserver.onCompleted(); + } + + boolean shouldRetry(DocumentSnapshot lastDocument, Throwable t) { + if (lastDocument == null) { + // Only retry if we have received a single result. Retries for RPCs with initial + // failure are handled by Google Gax, which also implements backoff. + return false; + } + + // Do not retry EXPLAIN requests because it'd be executing + // multiple queries. This means stats would have to be aggregated, + // and that may not even make sense for many statistics. + if (explainOptions != null) { + return false; + } + + Set retryableCodes = + FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes(); + return shouldRetryQuery(t, transactionId, startTimeNanos, retryableCodes); + } + }; + + rpcContext.streamRequest( + toRunQueryRequestBuilder(transactionId, readTime, explainOptions).build(), + observer, + rpcContext.getClient().runQueryCallable()); + } + + /** Returns whether a query that failed in the given scenario should be retried. */ + boolean shouldRetryQuery( + Throwable throwable, + @Nullable ByteString transactionId, + long startTimeNanos, + Set retryableCodes) { + if (transactionId != null) { + // Transactional queries are retried via the transaction runner. 
+ return false; + } + + if (!isRetryableError(throwable, retryableCodes)) { + return false; + } + + if (rpcContext.getTotalRequestTimeout().isZero()) { + return true; + } + + Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos); + return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0; + } + + /** Verifies whether the given exception is retryable based on the RunQuery configuration. */ + private boolean isRetryableError(Throwable throwable, Set retryableCodes) { + if (!(throwable instanceof FirestoreException)) { + return false; + } + Status status = ((FirestoreException) throwable).getStatus(); + for (StatusCode.Code code : retryableCodes) { + if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) { + return true; + } + } + return false; + } +} diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuery.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuery.java new file mode 100644 index 000000000..2f8c2ac42 --- /dev/null +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuery.java @@ -0,0 +1,192 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.firestore; + +import com.google.api.core.ApiFuture; +import com.google.cloud.Timestamp; +import com.google.firestore.v1.RunQueryRequest; +import com.google.firestore.v1.StructuredQuery; +import com.google.protobuf.ByteString; +import java.util.List; +import java.util.Objects; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * A query that finds the documents whose vector fields are closest to a certain query vector. + * Create an instance of `VectorQuery` with {@link Query#findNearest}. + */ +public final class VectorQuery extends StreamableQuery { + final Query query; + final FieldPath vectorField; + final VectorValue queryVector; + final int limit; + final DistanceMeasure distanceMeasure; + final VectorQueryOptions options; + + /** Creates a VectorQuery */ + VectorQuery( + Query query, + FieldPath vectorField, + VectorValue queryVector, + int limit, + DistanceMeasure distanceMeasure, + VectorQueryOptions options) { + super(query.rpcContext, query.options); + + this.query = query; + this.options = options; + this.vectorField = vectorField; + this.queryVector = queryVector; + this.limit = limit; + this.distanceMeasure = distanceMeasure; + } + + /** + * Executes the query and returns the results as {@link VectorQuerySnapshot}. + * + * @return An ApiFuture that will be resolved with the results of the VectorQuery. + */ + @Override + @Nonnull + public ApiFuture get() { + return get(null, null); + } + + /** + * Plans and optionally executes this VectorQuery. Returns an ApiFuture that will be resolved with + * the planner information, statistics from the query execution (if any), and the query results + * (if any). + * + * @return An ApiFuture that will be resolved with the planner information, statistics from the + * query execution (if any), and the query results (if any). 
+ */ + @Override + @Nonnull + public ApiFuture> explain(ExplainOptions options) { + return super.explain(options); + } + + /** + * Returns true if this VectorQuery is equal to the provided object. + * + * @param obj The object to compare against. + * @return Whether this VectorQuery is equal to the provided object. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || !(obj instanceof VectorQuery)) { + return false; + } + VectorQuery otherQuery = (VectorQuery) obj; + return Objects.equals(query, otherQuery.query) + && Objects.equals(vectorField, otherQuery.vectorField) + && Objects.equals(queryVector, otherQuery.queryVector) + && Objects.equals(options, otherQuery.options) + && limit == otherQuery.limit + && distanceMeasure == otherQuery.distanceMeasure; + } + + @Override + public int hashCode() { + return Objects.hash(query, options, vectorField, queryVector, limit, distanceMeasure); + } + + @Override + protected RunQueryRequest.Builder toRunQueryRequestBuilder( + @Nullable final ByteString transactionId, + @Nullable final Timestamp readTime, + @Nullable ExplainOptions explainOptions) { + + // Builder for the base query + RunQueryRequest.Builder requestBuilder = + query.toRunQueryRequestBuilder(transactionId, readTime, explainOptions); + + // Builder for find nearest + StructuredQuery.FindNearest.Builder findNearestBuilder = + requestBuilder.getStructuredQueryBuilder().getFindNearestBuilder(); + findNearestBuilder.getQueryVectorBuilder().setMapValue(this.queryVector.toProto()); + findNearestBuilder.getLimitBuilder().setValue(this.limit); + findNearestBuilder.setDistanceMeasure(toProto(this.distanceMeasure)); + findNearestBuilder.getVectorFieldBuilder().setFieldPath(this.vectorField.toString()); + + if (this.options != null) { + if (this.options.getDistanceThreshold() != null) { + findNearestBuilder + .getDistanceThresholdBuilder() + .setValue(this.options.getDistanceThreshold().doubleValue()); + } + if 
(this.options.getDistanceResultField() != null) { + findNearestBuilder.setDistanceResultField(this.options.getDistanceResultField().toString()); + } + } + + return requestBuilder; + } + + private static StructuredQuery.FindNearest.DistanceMeasure toProto( + DistanceMeasure distanceMeasure) { + switch (distanceMeasure) { + case COSINE: + return StructuredQuery.FindNearest.DistanceMeasure.COSINE; + case EUCLIDEAN: + return StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN; + case DOT_PRODUCT: + return StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT; + default: + return StructuredQuery.FindNearest.DistanceMeasure.UNRECOGNIZED; + } + } + + @Override + boolean isRetryableWithCursor() { + return false; + } + + @Override + VectorQuery startAfter(@Nonnull DocumentSnapshot snapshot) { + throw new RuntimeException("Not implemented"); + } + + @Override + VectorQuerySnapshot createSnaphot( + Timestamp readTime, final List documents) { + return VectorQuerySnapshot.withDocuments(this, readTime, documents); + } + + /** + * The distance measure to use when comparing vectors in a {@link VectorQuery}. + * + * @see com.google.cloud.firestore.Query#findNearest + */ + public enum DistanceMeasure { + /** + * COSINE distance compares vectors based on the angle between them, which allows you to measure + * similarity that isn't based on the vectors' magnitude. We recommend using DOT_PRODUCT with + * unit normalized vectors instead of COSINE distance, which is mathematically equivalent with + * better performance. + */ + COSINE, + /** Measures the EUCLIDEAN distance between the vectors. */ + EUCLIDEAN, + /** Similar to cosine but is affected by the magnitude of the vectors. 
*/ + DOT_PRODUCT + } +} diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQueryOptions.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQueryOptions.java new file mode 100644 index 000000000..38ca68d23 --- /dev/null +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQueryOptions.java @@ -0,0 +1,161 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.firestore; + +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * Specifies the behavior of the {@link VectorQuery} generated by a call to {@link + * Query#findNearest}. 
+ */ +public class VectorQueryOptions { + private final @Nullable FieldPath distanceResultField; + + private final @Nullable Double distanceThreshold; + + @Nullable + public FieldPath getDistanceResultField() { + return distanceResultField; + } + + @Nullable + public Double getDistanceThreshold() { + return distanceThreshold; + } + + VectorQueryOptions(VectorQueryOptions.Builder builder) { + this.distanceThreshold = builder.distanceThreshold; + this.distanceResultField = builder.distanceResultField; + } + + public static VectorQueryOptions.Builder newBuilder() { + return new VectorQueryOptions.Builder(); + } + + public static final class Builder { + /** + * Returns the name of the field that will be set on each returned DocumentSnapshot, which will + * contain the computed distance for the document. If `null`, then the computed distance will + * not be returned. Default value: `null`. + * + *

Set this value with {@link VectorQueryOptions.Builder#setDistanceResultField(FieldPath)} + * or {@link VectorQueryOptions.Builder#setDistanceResultField(String)}. + */ + private @Nullable FieldPath distanceResultField; + + /** + * Specifies a threshold for which no less similar documents will be returned. If `null`, then + * no distance threshold will be applied. Default value: `null`. + * + *

Set this value with {@link VectorQueryOptions.Builder#setDistanceThreshold(Double)}. + */ + private @Nullable Double distanceThreshold; + + private Builder() { + distanceThreshold = null; + distanceResultField = null; + } + + private Builder(VectorQueryOptions options) { + this.distanceThreshold = options.distanceThreshold; + this.distanceResultField = options.distanceResultField; + } + + /** + * Specifies the name of the field that will be set on each returned DocumentSnapshot, which + * will contain the computed distance for the document. If `null`, then the computed distance + * will not be returned. Default value: `null`. + * + * @param fieldPath A string value specifying the distance result field. + */ + public Builder setDistanceResultField(@Nullable String fieldPath) { + this.distanceResultField = FieldPath.fromDotSeparatedString(fieldPath); + return this; + } + + /** + * Specifies the name of the field that will be set on each returned DocumentSnapshot, which + * will contain the computed distance for the document. If `null`, then the computed distance + * will not be returned. Default value: `null`. + * + * @param fieldPath A {@link FieldPath} value specifying the distance result field. + */ + public Builder setDistanceResultField(@Nullable FieldPath fieldPath) { + this.distanceResultField = fieldPath; + return this; + } + + /** + * Specifies a threshold for which no less similar documents will be returned. The behavior of + * the specified `distanceMeasure` will affect the meaning of the distance threshold. + * + *

+ * + *

    + *
  • For `distanceMeasure: "EUCLIDEAN"`, the meaning of `distanceThreshold` is: {@code + * SELECT docs WHERE euclidean_distance <= distanceThreshold} + *
  • For `distanceMeasure: "COSINE"`, the meaning of `distanceThreshold` is: {@code SELECT + * docs WHERE cosine_distance <= distanceThreshold} + *
  • For `distanceMeasure: "DOT_PRODUCT"`, the meaning of `distanceThreshold` is: {@code + * SELECT docs WHERE dot_product_distance >= distanceThreshold} + *
+ * + *

If `null`, then no distance threshold will be applied. Default value: `null`. + * + * @param distanceThreshold A Double value specifying the distance threshold. + */ + public Builder setDistanceThreshold(@Nullable Double distanceThreshold) { + this.distanceThreshold = distanceThreshold; + return this; + } + + public VectorQueryOptions build() { + return new VectorQueryOptions(this); + } + } + + /** + * Returns true if this VectorQueryOptions is equal to the provided object. + * + * @param obj The object to compare against. + * @return Whether this VectorQueryOptions is equal to the provided object. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || !(obj instanceof VectorQueryOptions)) { + return false; + } + VectorQueryOptions otherOptions = (VectorQueryOptions) obj; + return Objects.equals(distanceResultField, otherOptions.distanceResultField) + && Objects.equals(distanceThreshold, otherOptions.distanceThreshold); + } + + /** Default VectorQueryOptions instance. */ + private static VectorQueryOptions DEFAULT = newBuilder().build(); + + /** + * Returns a default {@code VectorQueryOptions} instance. Note: package private until API review + * can be completed. + */ + static VectorQueryOptions getDefaultInstance() { + return DEFAULT; + } +} diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuerySnapshot.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuerySnapshot.java new file mode 100644 index 000000000..528512fad --- /dev/null +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorQuerySnapshot.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.firestore; + +import com.google.cloud.Timestamp; +import java.util.List; + +/** + * A VectorQuerySnapshot contains the results of a VectorQuery. It can contain zero or more + * DocumentSnapshot objects. + */ +public class VectorQuerySnapshot extends GenericQuerySnapshot { + protected VectorQuerySnapshot( + VectorQuery query, + Timestamp readTime, + final List documents, + final List documentChanges) { + super(query, readTime, documents, documentChanges); + } + + /** + * Creates a new VectorQuerySnapshot representing the results of a VectorQuery with added + * documents. + */ + public static VectorQuerySnapshot withDocuments( + final VectorQuery query, Timestamp readTime, final List documents) { + return new VectorQuerySnapshot(query, readTime, documents, null); + } +} diff --git a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java index b95e8d08a..347e721f1 100644 --- a/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java +++ b/google-cloud-firestore/src/main/java/com/google/cloud/firestore/VectorValue.java @@ -69,4 +69,11 @@ public int hashCode() { MapValue toProto() { return UserDataConverter.encodeVector(this.values); } + + /** + * Returns the number of dimensions of the vector. Note: package private until API review is done. 
+ */ + int size() { + return this.values.length; + } } diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java index 70d497b75..5ccd9f163 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/LocalFirestoreHelper.java @@ -63,10 +63,7 @@ import com.google.firestore.v1.Write; import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; -import com.google.protobuf.ByteString; -import com.google.protobuf.Empty; -import com.google.protobuf.Message; -import com.google.protobuf.NullValue; +import com.google.protobuf.*; import com.google.type.LatLng; import java.lang.reflect.Type; import java.math.BigInteger; @@ -649,6 +646,39 @@ public static StructuredQuery filter(StructuredQuery.FieldFilter.Operator operat return filter(operator, "foo", "bar"); } + public static StructuredQuery findNearest( + String fieldPath, + double[] queryVector, + int limit, + StructuredQuery.FindNearest.DistanceMeasure measure) { + ArrayValue.Builder vectorArrayBuilder = ArrayValue.newBuilder(); + for (double d : queryVector) { + vectorArrayBuilder.addValues(Value.newBuilder().setDoubleValue(d)); + } + + StructuredQuery.FindNearest.Builder findNearest = + StructuredQuery.FindNearest.newBuilder() + .setVectorField(StructuredQuery.FieldReference.newBuilder().setFieldPath(fieldPath)) + .setQueryVector( + Value.newBuilder() + .setMapValue( + MapValue.newBuilder() + .putFields( + "__type__", Value.newBuilder().setStringValue("__vector__").build()) + .putFields( + "value", + Value.newBuilder() + .setArrayValue(vectorArrayBuilder.build()) + .build()))) + .setLimit(Int32Value.newBuilder().setValue(limit)) + .setDistanceMeasure(measure); + + StructuredQuery.Builder structuredQuery = StructuredQuery.newBuilder(); + 
structuredQuery.setFindNearest(findNearest.build()); + + return structuredQuery.build(); + } + public static StructuredQuery filter( StructuredQuery.FieldFilter.Operator operator, String path, String value) { return filter(operator, path, string(value)); diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/VectorQueryTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/VectorQueryTest.java new file mode 100644 index 000000000..717cf6e66 --- /dev/null +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/VectorQueryTest.java @@ -0,0 +1,482 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.firestore; + +import static com.google.cloud.firestore.LocalFirestoreHelper.COLLECTION_ID; +import static com.google.cloud.firestore.LocalFirestoreHelper.DOCUMENT_NAME; +import static com.google.cloud.firestore.LocalFirestoreHelper.findNearest; +import static com.google.cloud.firestore.LocalFirestoreHelper.query; +import static com.google.cloud.firestore.LocalFirestoreHelper.queryResponse; +import static com.google.cloud.firestore.LocalFirestoreHelper.queryResponseWithDone; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; + +import com.google.api.gax.rpc.ResponseObserver; +import com.google.cloud.firestore.spi.v1.FirestoreRpc; +import com.google.firestore.v1.RunQueryRequest; +import com.google.firestore.v1.RunQueryResponse; +import com.google.firestore.v1.StructuredQuery; +import io.grpc.Status; +import java.util.List; +import java.util.concurrent.ExecutionException; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mockito; +import org.mockito.Spy; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class VectorQueryTest { + + @Spy + private final FirestoreImpl firestoreMock = + new FirestoreImpl( + FirestoreOptions.newBuilder().setProjectId("test-project").build(), + Mockito.mock(FirestoreRpc.class)); + + @Captor private ArgumentCaptor runQuery; + + @Captor private ArgumentCaptor> streamObserverCapture; + + private Query queryA; + private Query queryB; + + private QueryTest.MockClock clock; + + @Before + public void before() { + clock = new QueryTest.MockClock(); + 
doReturn(clock).when(firestoreMock).getClock(); + + queryA = firestoreMock.collection(COLLECTION_ID); + queryB = firestoreMock.collection(COLLECTION_ID).whereEqualTo("foo", "bar"); + } + + @Test + public void isEqual() { + Query queryA = firestoreMock.collection("collectionId").whereEqualTo("foo", "bar"); + Query queryB = firestoreMock.collection("collectionId").whereEqualTo("foo", "bar"); + Query queryC = firestoreMock.collection("collectionId"); + + assertEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE), + queryB.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE)); + + assertEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN), + queryB.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN)); + + assertEquals( + queryA.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceThreshold(0.125).build()), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceThreshold(0.125).build())); + + assertEquals( + queryA.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder() + .setDistanceThreshold(0.125) + .setDistanceResultField(FieldPath.of("foo")) + .build()), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder() + .setDistanceThreshold(0.125) + .setDistanceResultField(FieldPath.of("foo")) + .build())); + + assertEquals( + queryA.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + 
VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder() + .setDistanceResultField(FieldPath.of("distance")) + .build())); + + assertNotEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE), + queryC.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE)); + + assertNotEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE), + queryB.findNearest( + "embedding", new double[] {40, 42}, 10, VectorQuery.DistanceMeasure.COSINE)); + + assertNotEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE), + queryB.findNearest( + "embedding", new double[] {40, 41, 42}, 1000, VectorQuery.DistanceMeasure.COSINE)); + + assertNotEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.COSINE), + queryB.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN)); + + assertNotEquals( + queryA.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceThreshold(0.125).build()), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceThreshold(1.125).build())); + + assertNotEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build())); + + assertNotEquals( + queryA.findNearest( + "embedding", + new double[] 
{40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build()), + queryB.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN)); + + assertNotEquals( + queryA.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceResultField("result").build())); + + assertNotEquals( + queryA.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder() + .setDistanceResultField(FieldPath.of("distance")) + .build()), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder() + .setDistanceResultField(FieldPath.of("result")) + .build())); + + assertNotEquals( + queryA.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN), + queryB.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceResultField("result").build())); + + assertNotEquals( + queryA.findNearest( + "embedding", + new double[] {40, 41, 42}, + 10, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()), + queryB.findNearest( + "embedding", new double[] {40, 41, 42}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN)); + } + + @Test + public void validatesInputsLimit() { + String expectedExceptionMessage = ".*Not a valid positive `limit` number.*"; + Throwable exception = + assertThrows( + RuntimeException.class, + () -> + queryA.findNearest( + "embedding", new double[] {10, 100}, 0, 
VectorQuery.DistanceMeasure.EUCLIDEAN)); + assertTrue(exception.getMessage().matches(expectedExceptionMessage)); + } + + @Test + public void validatesInputsVectorSize() { + String expectedExceptionMessage = ".*Not a valid vector size.*"; + Throwable exception = + assertThrows( + RuntimeException.class, + () -> + queryA.findNearest( + "embedding", new double[0], 10, VectorQuery.DistanceMeasure.EUCLIDEAN)); + assertTrue(exception.getMessage().matches(expectedExceptionMessage)); + } + + @Test + public void successfulReturnWithoutOnComplete() throws Exception { + doAnswer( + queryResponseWithDone( + /* callWithoutOnComplete */ true, DOCUMENT_NAME + "1", DOCUMENT_NAME + "2")) + .when(firestoreMock) + .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any()); + + VectorQuerySnapshot snapshot = + queryA + .findNearest("vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.COSINE) + .get() + .get(); + + assertEquals(2, snapshot.size()); + assertEquals(false, snapshot.isEmpty()); + assertTrue((DOCUMENT_NAME + "1").endsWith(snapshot.getDocuments().get(0).getId())); + assertTrue((DOCUMENT_NAME + "2").endsWith(snapshot.getDocuments().get(1).getId())); + assertEquals(2, snapshot.getDocumentChanges().size()); + assertEquals(1, snapshot.getReadTime().getSeconds()); + assertEquals(2, snapshot.getReadTime().getNanos()); + + assertEquals(1, runQuery.getAllValues().size()); + assertEquals( + query( + findNearest( + "vector", + new double[] {1, -9.5}, + 10, + StructuredQuery.FindNearest.DistanceMeasure.COSINE)), + runQuery.getValue()); + } + + @Test + public void successfulReturn() throws Exception { + doAnswer(queryResponse(DOCUMENT_NAME + "1", DOCUMENT_NAME + "2")) + .when(firestoreMock) + .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any()); + + VectorQuerySnapshot snapshot = + queryA + .findNearest( + "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT) + .get() + .get(); + + assertEquals(2, 
snapshot.size()); + assertEquals(false, snapshot.isEmpty()); + assertTrue((DOCUMENT_NAME + "1").endsWith(snapshot.getDocuments().get(0).getId())); + assertTrue((DOCUMENT_NAME + "2").endsWith(snapshot.getDocuments().get(1).getId())); + assertEquals(2, snapshot.getDocumentChanges().size()); + assertEquals(1, snapshot.getReadTime().getSeconds()); + assertEquals(2, snapshot.getReadTime().getNanos()); + + assertEquals(1, runQuery.getAllValues().size()); + assertEquals( + query( + findNearest( + "vector", + new double[] {1, -9.5}, + 10, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT)), + runQuery.getValue()); + } + + @Test + public void handlesStreamExceptionRetryableError() { + final boolean[] returnError = new boolean[] {true}; + + doAnswer( + invocation -> { + if (returnError[0]) { + returnError[0] = false; + return queryResponse( + FirestoreException.forServerRejection( + Status.DEADLINE_EXCEEDED, "Simulated test failure"), + DOCUMENT_NAME + "1", + DOCUMENT_NAME + "2") + .answer(invocation); + } else { + return queryResponse(DOCUMENT_NAME + "3").answer(invocation); + } + }) + .when(firestoreMock) + .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any()); + + ExecutionException e = + assertThrows( + ExecutionException.class, + () -> + queryA + .findNearest( + "vector", + new double[] {1, -9.5}, + 10, + VectorQuery.DistanceMeasure.DOT_PRODUCT) + .get() + .get()); + + // Verify the requests + List requests = runQuery.getAllValues(); + assertEquals(1, requests.size()); + + assertTrue(requests.get(0).getStructuredQuery().hasFindNearest()); + + assertEquals(e.getCause().getClass(), FirestoreException.class); + assertTrue(e.getCause().getMessage().matches(".*Simulated test failure.*")); + } + + @Test + public void handlesStreamExceptionNonRetryableError() { + final boolean[] returnError = new boolean[] {true}; + + doAnswer( + invocation -> { + if (returnError[0]) { + returnError[0] = false; + return queryResponse( + 
FirestoreException.forServerRejection( + Status.PERMISSION_DENIED, "Simulated test failure")) + .answer(invocation); + } else { + return queryResponse(DOCUMENT_NAME + "3").answer(invocation); + } + }) + .when(firestoreMock) + .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any()); + + ExecutionException e = + assertThrows( + ExecutionException.class, + () -> + queryA + .findNearest( + "vector", + new double[] {1, -9.5}, + 10, + VectorQuery.DistanceMeasure.DOT_PRODUCT) + .get() + .get()); + + // Verify the requests + List requests = runQuery.getAllValues(); + assertEquals(1, requests.size()); + + assertTrue(requests.get(0).getStructuredQuery().hasFindNearest()); + + assertEquals(e.getCause().getClass(), FirestoreException.class); + assertTrue(e.getCause().getMessage().matches(".*Simulated test failure.*")); + } + + @Test + public void vectorQuerySnapshotEquality() throws Exception { + final int[] count = {0}; + + doAnswer( + invocation -> { + switch (count[0]++) { + case 0: + return queryResponse(DOCUMENT_NAME + "3", DOCUMENT_NAME + "4").answer(invocation); + case 1: + return queryResponse(DOCUMENT_NAME + "3", DOCUMENT_NAME + "4").answer(invocation); + case 2: + return queryResponse(DOCUMENT_NAME + "3", DOCUMENT_NAME + "4").answer(invocation); + case 3: + return queryResponse(DOCUMENT_NAME + "4").answer(invocation); + default: + return queryResponse().answer(invocation); + } + }) + .when(firestoreMock) + .streamRequest(runQuery.capture(), streamObserverCapture.capture(), any()); + + VectorQuerySnapshot snapshotA = + queryA + .findNearest( + "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT) + .get() + .get(); + VectorQuerySnapshot snapshotB = + queryA + .findNearest( + "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT) + .get() + .get(); + VectorQuerySnapshot snapshotC = + queryB + .findNearest( + "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT) + .get() + 
.get(); + VectorQuerySnapshot snapshotD = + queryA + .findNearest( + "vector", new double[] {1, -9.5}, 10, VectorQuery.DistanceMeasure.DOT_PRODUCT) + .get() + .get(); + + assertEquals(snapshotA, snapshotB); + assertNotEquals(snapshotA, snapshotC); + assertNotEquals(snapshotA, snapshotD); + } +} diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java index fce1d87d6..438762e14 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java @@ -21,6 +21,7 @@ import static com.google.cloud.firestore.it.TestHelper.isRunningAgainstFirestoreEmulator; import static com.google.common.truth.Truth.assertThat; import static java.util.Collections.singletonMap; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThrows; import static org.junit.Assume.assumeFalse; import static org.junit.Assume.assumeTrue; diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryFindNearestTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryFindNearestTest.java new file mode 100644 index 000000000..123e9d4a9 --- /dev/null +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryFindNearestTest.java @@ -0,0 +1,933 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.firestore.it; + +import static com.google.cloud.firestore.LocalFirestoreHelper.autoId; +import static com.google.cloud.firestore.LocalFirestoreHelper.map; +import static com.google.cloud.firestore.it.TestHelper.await; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.*; + +import com.google.cloud.firestore.*; +import java.time.Duration; +import java.util.Arrays; +import java.util.Map; +import javax.annotation.Nullable; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ITQueryFindNearestTest extends ITBaseTest { + static String testId = autoId(); + + @Rule public TestName testName = new TestName(); + + private String getUniqueTestId() { + return testId + "-" + testName.getMethodName(); + } + + private CollectionReference testCollection() { + String collectionPath = "index-test-collection"; + return firestore.collection(collectionPath); + } + + private CollectionReference testCollectionWithDocs(Map> docs) + throws InterruptedException { + CollectionReference collection = testCollection(); + CollectionReference writer = firestore.collection(collection.getId()); + writeAllDocs(writer, docs); + return collection; + } + + private String getUniqueDocId(String docId) { + return testId + docId; + } + + public void writeAllDocs(CollectionReference collection, Map> docs) + throws InterruptedException { + for (Map.Entry> doc : docs.entrySet()) { + Map data = doc.getValue(); + data.put("testId", getUniqueTestId()); + data.put("docId", getUniqueDocId(doc.getKey())); + await(collection.add(data)); + } + } + + @Test + public void findNearestWithEuclideanDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", map("foo", 
"bar"), + "b", map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})), + "c", map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})), + "d", map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})), + "e", map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})), + "f", map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("d")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("c")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("e")); + } + + @Test + public void findNearestWithEuclideanDistanceFirestoreTypeOverride() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", map("foo", "bar"), + "b", map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})), + "c", map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})), + "d", map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})), + "e", map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})), + "f", map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + FieldPath.of("embedding"), + FieldValue.vector(new double[] {10, 10}), + 3, + VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + 
assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("d")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("c")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("e")); + } + + @Test + public void findNearestWithCosineDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", + map("foo", "bar"), + "b", + map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})), + "c", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})), + "d", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})), + "e", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})), + "f", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest("embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.COSINE); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + assertTrue( + Arrays.asList(getUniqueDocId("f"), getUniqueDocId("c")) + .contains(snapshot.getDocuments().get(0).get("docId"))); + assertTrue( + Arrays.asList(getUniqueDocId("f"), getUniqueDocId("c")) + .contains(snapshot.getDocuments().get(1).get("docId"))); + assertTrue( + Arrays.asList(getUniqueDocId("d"), getUniqueDocId("e")) + .contains(snapshot.getDocuments().get(2).get("docId"))); + } + + @Test + public void findNearestWithDotProductDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", + map("foo", "bar"), + "b", + map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})), + "c", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})), + "d", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})), + "e", 
+ map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})), + "f", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.DOT_PRODUCT); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("f")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("e")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("d")); + } + + @Test + public void findNearestSkipsFieldsOfWrongTypes() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", + map("foo", "bar"), + "b", + map("foo", "bar", "embedding", Arrays.asList(10.0, 10.0)), + "c", + map("foo", "bar", "embedding", "not actually a vector"), + "d", + map("foo", "bar", "embedding", null), + "e", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {9, 9})), + "f", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {50, 50})), + "g", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {10, 10}, 100, VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("e")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("f")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("g")); + } + + @Test + public void 
findNearestIgnoresMismatchingDimensions() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", + map("foo", "bar"), + "b", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {10})), + "c", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {9, 9})), + "d", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {50, 50})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(2); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("c")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("d")); + } + + @Test + public void findNearestOnNonExistentField() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", + map("foo", "bar"), + "b", + map("foo", "bar", "otherField", Arrays.asList(10.0, 10.0)), + "c", + map("foo", "bar", "otherField", "not actually a vector"), + "d", + map("foo", "bar", "otherField", null))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(0); + } + + @Test + public void findNearestOnVectorNestedInAMap() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", + map("nested", map("foo", "bar")), + "b", + map( + "nested", + map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10}))), + "c", + map( + "nested", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1}))), + "d", + map( + 
"nested", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0}))), + "e", + map( + "nested", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0}))), + "f", + map( + "nested", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100}))))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "nested.embedding", + new double[] {10, 10}, + 3, + VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("b")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("d")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("c")); + } + + @Test + public void findNearestWithSelectToExcludeVectorDataInResponse() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", + map("foo", 1), + "b", + map("foo", 2, "embedding", FieldValue.vector(new double[] {10, 10})), + "c", + map("foo", 3, "embedding", FieldValue.vector(new double[] {1, 1})), + "d", + map("foo", 4, "embedding", FieldValue.vector(new double[] {10, 0})), + "e", + map("foo", 5, "embedding", FieldValue.vector(new double[] {20, 0})), + "f", + map("foo", 6, "embedding", FieldValue.vector(new double[] {100, 100})))); + + VectorQuery vectorQuery = + collection + .whereIn("foo", Arrays.asList(1, 2, 3, 4, 5, 6)) + .whereEqualTo("testId", getUniqueTestId()) + .select("foo", "docId") + .findNearest( + "embedding", new double[] {10, 10}, 10, VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(5); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("b")); + 
assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("d")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("c")); + assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("e")); + assertThat(snapshot.getDocuments().get(4).get("docId")).isEqualTo(getUniqueDocId("f")); + + for (QueryDocumentSnapshot doc : snapshot.getDocuments()) { + assertThat(doc.get("embedding")).isNull(); + } + } + + @Test + public void findNearestLimits() throws Exception { + double[] embeddingVector = new double[2048]; + double[] queryVector = new double[2048]; + for (int i = 0; i < 2048; i++) { + embeddingVector[i] = i + 1; + queryVector[i] = i - 1; + } + + CollectionReference collection = + testCollectionWithDocs(map("a", map("embedding", FieldValue.vector(embeddingVector)))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest("embedding", queryVector, 1000, VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(1); + + assertThat(((VectorValue) snapshot.getDocuments().get(0).get("embedding")).toArray()) + .isEqualTo(embeddingVector); + } + + @Test + public void requestingComputedCosineDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, 1})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, -0.1})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-1, 0})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.COSINE, + 
VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(4); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(0); + + assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1); + assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(1); + + assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("5")); + assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(2); + } + + @Test + public void requestingComputedEuclideanDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, -0.1})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {4, 4})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(4); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("4")); + assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(0.1); + + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1); + + 
assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("5")); + assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(5); + + assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("3")); + assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(100); + } + + /** Runs a DOT_PRODUCT findNearest for {1, 0} with setDistanceResultField("distance") and expects documents "2", "3", "5", "4" in that order, with "distance" values 2, 1, 0.1 and -20. */ @Test + public void requestingComputedDotProductDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.DOT_PRODUCT, + VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(4); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(2); + + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3")); + assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1); + + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("5")); + assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(0.1); + + assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("4")); + assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(-20); + } + + /** Stores a document that already carries a string field "distance" ("100 miles"), queries with "distance" as the distance result field, and expects getDouble("distance") on the result to equal 1. */ @Test + public void overwritesDistanceResultFieldOnConflict() + 
throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map( + "foo", + "bar", + "embedding", + FieldValue.vector(new double[] {0, 1}), + "distance", + "100 miles"))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.COSINE, + VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(1); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("1")); + assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(1); + } + + /** Chains select("embedding", "distance", "docId") before a COSINE findNearest whose distance result field is "distance", and checks the "distance" values on the 4 returned documents. */ @Test + public void requestingComputedDistanceInSelectQueries() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, 1})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, -0.1})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-1, 0})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .select("embedding", "distance", "docId") + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.COSINE, + VectorQueryOptions.newBuilder().setDistanceResultField("distance").build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(4); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(0).getDouble("distance")).isEqualTo(0); + + assertThat(snapshot.getDocuments().get(1).getDouble("distance")).isEqualTo(1); + 
assertThat(snapshot.getDocuments().get(2).getDouble("distance")).isEqualTo(1); + + assertThat(snapshot.getDocuments().get(3).get("docId")).isEqualTo(getUniqueDocId("5")); + assertThat(snapshot.getDocuments().get(3).getDouble("distance")).isEqualTo(2); + } + + /** Runs a COSINE findNearest for {1, 0} with setDistanceThreshold(1.0) and expects exactly documents "2", "3" and "4", in that order. */ @Test + public void queryingWithDistanceThresholdUsingCosineDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0, -0.1})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-1, 0})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.COSINE, + VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("4")); + } + + /** Runs an EUCLIDEAN findNearest for {1, 0} with setDistanceThreshold(5.0) and expects exactly documents "4", "2" and "5", in that order. */ @Test + public void queryingWithDistanceThresholdUsingEuclideanDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, -0.1})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {4, 4})))); + + VectorQuery vectorQuery = + collection + 
.whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.EUCLIDEAN, + VectorQueryOptions.newBuilder().setDistanceThreshold(5.0).build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(3); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("4")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(2).get("docId")).isEqualTo(getUniqueDocId("5")); + } + + /** Runs a DOT_PRODUCT findNearest for {1, 0} with setDistanceThreshold(1.0) and expects exactly documents "2" and "3", in that order. */ @Test + public void queryingWithDistanceThresholdUsingDotProductDistance() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.DOT_PRODUCT, + VectorQueryOptions.newBuilder().setDistanceThreshold(1.0).build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(2); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3")); + } + + /** Runs a DOT_PRODUCT findNearest with both setDistanceThreshold(0.11) and setDistanceResultField("foo"); every stored document also has a string field "foo". Expects documents "2" and "3" with getDouble("foo") equal to 2 and 1. */ @Test + public void queryWithDistanceResultFieldAndDistanceThreshold() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", 
FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 5, + VectorQuery.DistanceMeasure.DOT_PRODUCT, + VectorQueryOptions.newBuilder() + .setDistanceThreshold(0.11) + .setDistanceResultField("foo") + .build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(2); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(0).getDouble("foo")).isEqualTo(2); + + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3")); + assertThat(snapshot.getDocuments().get(1).getDouble("foo")).isEqualTo(1); + } + + /** Calls findNearest with FieldPath.of("embedding") and FieldValue.vector(new double[] {1, 0}) as arguments, DOT_PRODUCT distance, setDistanceThreshold(0.11) and setDistanceResultField("foo"); expects documents "2" and "3" with getDouble("foo") equal to 2 and 1. */ @Test + public void queryWithDistanceResultFieldAndDistanceThresholdWithFirestoreTypes() + throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + FieldPath.of("embedding"), + FieldValue.vector(new double[] {1, 0}), + 5, + VectorQuery.DistanceMeasure.DOT_PRODUCT, + VectorQueryOptions.newBuilder() + .setDistanceThreshold(0.11) + .setDistanceResultField("foo") + .build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(2); + + 
assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(0).getDouble("foo")).isEqualTo(2); + + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3")); + assertThat(snapshot.getDocuments().get(1).getDouble("foo")).isEqualTo(1); + } + + /** Runs a DOT_PRODUCT findNearest with limit 2 and setDistanceThreshold(0.0) and expects only 2 documents back, "2" then "3". */ @Test + public void willNotExceedLimitEvenIfThereAreMoreResultsMoreSimilarThanDistanceThreshold() + throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", + new double[] {1, 0}, + 2, // limit set to 2 + VectorQuery.DistanceMeasure.DOT_PRODUCT, + VectorQueryOptions.newBuilder().setDistanceThreshold(0.0).build()); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertThat(snapshot.size()).isEqualTo(2); + + assertThat(snapshot.getDocuments().get(0).get("docId")).isEqualTo(getUniqueDocId("2")); + assertThat(snapshot.getDocuments().get(1).get("docId")).isEqualTo(getUniqueDocId("3")); + } + + /** Calls explain with setAnalyze(false) on a vector query and expects a null snapshot, non-null metrics, a plan summary that lists at least one index, and null execution stats. */ @Test + public void testVectorQueryPlan() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4})))); + + VectorQuery vectorQuery = + collection + 
.whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {1, 0}, 5, VectorQuery.DistanceMeasure.DOT_PRODUCT); + + ExplainResults explainResults = + vectorQuery.explain(ExplainOptions.builder().setAnalyze(false).build()).get(); + + @Nullable VectorQuerySnapshot snapshot = explainResults.getSnapshot(); + assertThat(snapshot).isNull(); + + ExplainMetrics metrics = explainResults.getMetrics(); + assertThat(metrics).isNotNull(); + + PlanSummary planSummary = metrics.getPlanSummary(); + assertThat(planSummary).isNotNull(); + assertThat(planSummary.getIndexesUsed()).isNotEmpty(); + + ExecutionStats stats = metrics.getExecutionStats(); + assertThat(stats).isNull(); + } + + /** Calls explain with setAnalyze(true) on a vector query and expects a snapshot of size 4, a plan summary that lists at least one index, and non-null execution stats. */ @Test + public void testVectorQueryProfile() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "1", + map("foo", "bar"), + "2", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {2, 0})), + "3", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 100})), + "4", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {-20, 0})), + "5", + map("foo", "bar", "embedding", FieldValue.vector(new double[] {0.1, 4})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {1, 0}, 5, VectorQuery.DistanceMeasure.DOT_PRODUCT); + + ExplainResults explainResults = + vectorQuery.explain(ExplainOptions.builder().setAnalyze(true).build()).get(); + + @Nullable VectorQuerySnapshot snapshot = explainResults.getSnapshot(); + assertThat(snapshot).isNotNull(); + assertThat(snapshot.size()).isEqualTo(4); + + ExplainMetrics metrics = explainResults.getMetrics(); + assertThat(metrics).isNotNull(); + + PlanSummary planSummary = metrics.getPlanSummary(); + assertThat(planSummary).isNotNull(); + assertThat(planSummary.getIndexesUsed()).isNotEmpty(); + + ExecutionStats stats = metrics.getExecutionStats(); + assertThat(stats).isNotNull(); + 
assertThat(stats.getDebugStats()).isNotEmpty(); + assertThat(stats.getReadOperations()).isEqualTo(5); + assertThat(stats.getResultsReturned()).isEqualTo(4); + assertThat(stats.getExecutionDuration()).isGreaterThan(Duration.ZERO); + } + + @Test + public void vectorQuerySnapshotReturnsVectorQuery() throws Exception { + CollectionReference collection = + testCollectionWithDocs( + map( + "a", map("foo", "bar"), + "b", map("foo", "xxx", "embedding", FieldValue.vector(new double[] {10, 10})), + "c", map("foo", "bar", "embedding", FieldValue.vector(new double[] {1, 1})), + "d", map("foo", "bar", "embedding", FieldValue.vector(new double[] {10, 0})), + "e", map("foo", "bar", "embedding", FieldValue.vector(new double[] {20, 0})), + "f", map("foo", "bar", "embedding", FieldValue.vector(new double[] {100, 100})))); + + VectorQuery vectorQuery = + collection + .whereEqualTo("foo", "bar") + .whereEqualTo("testId", getUniqueTestId()) + .findNearest( + "embedding", new double[] {10, 10}, 3, VectorQuery.DistanceMeasure.EUCLIDEAN); + + VectorQuerySnapshot snapshot = vectorQuery.get().get(); + + assertTrue(snapshot.getQuery() == vectorQuery); + } +}