Skip to content

Commit

Permalink
feat: add support for the VectorValue type (#1716)
Browse files Browse the repository at this point in the history
Implement VectorValue type support.
  • Loading branch information
MarkDuckworth authored Aug 12, 2024
1 parent 4384970 commit 81bfa0d
Show file tree
Hide file tree
Showing 11 changed files with 509 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ private static <T> Object serialize(T o, ErrorPath path) {
|| o instanceof Blob
|| o instanceof DocumentReference
|| o instanceof FieldValue
|| o instanceof Value) {
|| o instanceof Value
|| o instanceof VectorValue) {
return o;
} else if (o instanceof Instant) {
Instant instant = (Instant) o;
Expand Down Expand Up @@ -243,6 +244,8 @@ private static <T> T deserializeToClass(Object o, Class<T> clazz, DeserializeCon
return (T) convertBlob(o, context);
} else if (GeoPoint.class.isAssignableFrom(clazz)) {
return (T) convertGeoPoint(o, context);
} else if (VectorValue.class.isAssignableFrom(clazz)) {
return (T) convertVectorValue(o, context);
} else if (DocumentReference.class.isAssignableFrom(clazz)) {
return (T) convertDocumentReference(o, context);
} else if (clazz.isArray()) {
Expand Down Expand Up @@ -596,6 +599,16 @@ private static GeoPoint convertGeoPoint(Object o, DeserializeContext context) {
}
}

private static VectorValue convertVectorValue(Object o, DeserializeContext context) {
if (o instanceof VectorValue) {
return (VectorValue) o;
} else {
throw deserializeError(
context.errorPath,
"Failed to convert value of type " + o.getClass().getName() + " to VectorValue");
}
}

private static DocumentReference convertDocumentReference(Object o, DeserializeContext context) {
if (o instanceof DocumentReference) {
return (DocumentReference) o;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,18 @@ public GeoPoint getGeoPoint(@Nonnull String field) {
return (GeoPoint) get(field);
}

/**
* Returns the value of the field as a VectorValue.
*
* @param field The path to the field.
* @throws RuntimeException if the value is not a VectorValue.
* @return The value of the field.
*/
@Nullable
public VectorValue getVectorValue(@Nonnull String field) {
return (VectorValue) get(field);
}

/**
* Gets the reference to the document.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ public static FieldValue arrayRemove(@Nonnull Object... elements) {
return new ArrayRemoveFieldValue(Arrays.asList(elements));
}

/**
* Creates a new {@link VectorValue} constructed with a copy of the given array of doubles.
*
* @param values Create a {@link VectorValue} instance with a copy of this array of doubles.
* @return A new {@link VectorValue} constructed with a copy of the given array of doubles.
*/
@Nonnull
public static VectorValue vector(@Nonnull double[] values) {
return new VectorValue(values);
}

/** Whether this FieldTransform should be included in the document mask. */
abstract boolean includeInDocumentMask();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

abstract class MapType {
static final String RESERVED_MAP_KEY = "__type__";
static final String RESERVED_MAP_KEY_VECTOR_VALUE = "__vector__";
static final String VECTOR_MAP_VECTORS_KEY = "value";
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.google.cloud.firestore;

import com.google.firestore.v1.MapValue;
import com.google.firestore.v1.Value;
import com.google.firestore.v1.Value.ValueTypeCase;
import com.google.protobuf.ByteString;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
Expand All @@ -40,6 +42,7 @@ enum TypeOrder implements Comparable<TypeOrder> {
REF,
GEO_POINT,
ARRAY,
VECTOR,
OBJECT;

static TypeOrder fromValue(Value value) {
Expand All @@ -65,13 +68,24 @@ static TypeOrder fromValue(Value value) {
case ARRAY_VALUE:
return ARRAY;
case MAP_VALUE:
return OBJECT;
return fromMapValue(value.getMapValue());
default:
throw new IllegalArgumentException("Could not detect value type for " + value);
}
}
}

static TypeOrder fromMapValue(MapValue mapValue) {
switch (UserDataConverter.detectMapRepresentation(mapValue)) {
case VECTOR_VALUE:
return TypeOrder.VECTOR;
case UNKNOWN:
case NONE:
default:
return TypeOrder.OBJECT;
}
}

static final Order INSTANCE = new Order();

private Order() {}
Expand Down Expand Up @@ -113,6 +127,8 @@ public int compare(@Nonnull Value left, @Nonnull Value right) {
left.getArrayValue().getValuesList(), right.getArrayValue().getValuesList());
case OBJECT:
return compareObjects(left, right);
case VECTOR:
return compareVectors(left, right);
default:
throw new IllegalArgumentException("Cannot compare " + leftType);
}
Expand Down Expand Up @@ -209,6 +225,30 @@ private int compareObjects(Value left, Value right) {
return Boolean.compare(leftIterator.hasNext(), rightIterator.hasNext());
}

private int compareVectors(Value left, Value right) {
// The vector is a map, but only vector value is compared.
Value leftValueField =
left.getMapValue().getFieldsOrDefault(MapType.VECTOR_MAP_VECTORS_KEY, null);
Value rightValueField =
right.getMapValue().getFieldsOrDefault(MapType.VECTOR_MAP_VECTORS_KEY, null);

List<Value> leftArray =
(leftValueField != null)
? leftValueField.getArrayValue().getValuesList()
: Collections.emptyList();
List<Value> rightArray =
(rightValueField != null)
? rightValueField.getArrayValue().getValuesList()
: Collections.emptyList();

Integer lengthCompare = Long.compare(leftArray.size(), rightArray.size());
if (lengthCompare != 0) {
return lengthCompare;
}

return compareArrays(leftArray, rightArray);
}

private int compareNumbers(Value left, Value right) {
if (left.getValueTypeCase() == ValueTypeCase.DOUBLE_VALUE) {
if (right.getValueTypeCase() == ValueTypeCase.DOUBLE_VALUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.primitives.Doubles;
import com.google.firestore.v1.ArrayValue;
import com.google.firestore.v1.MapValue;
import com.google.firestore.v1.Value;
Expand All @@ -32,10 +33,12 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import javax.annotation.Nullable;

/** Converts user input into the Firestore Value representation. */
class UserDataConverter {
private static final Logger LOGGER = Logger.getLogger(UserDataConverter.class.getName());

/** Controls the behavior for field deletes. */
interface EncodingOptions {
Expand Down Expand Up @@ -183,12 +186,34 @@ static Value encodeValue(
// send the map.
return null;
}
} else if (sanitizedObject instanceof VectorValue) {
VectorValue vectorValue = (VectorValue) sanitizedObject;
return Value.newBuilder().setMapValue(vectorValue.toProto()).build();
}

throw FirestoreException.forInvalidArgument(
"Cannot convert %s to Firestore Value", sanitizedObject);
}

static MapValue encodeVector(double[] rawVector) {
MapValue.Builder res = MapValue.newBuilder();

res.putFields(
MapType.RESERVED_MAP_KEY,
encodeValue(
FieldPath.fromDotSeparatedString(MapType.RESERVED_MAP_KEY),
MapType.RESERVED_MAP_KEY_VECTOR_VALUE,
ARGUMENT));
res.putFields(
MapType.VECTOR_MAP_VECTORS_KEY,
encodeValue(
FieldPath.fromDotSeparatedString(MapType.RESERVED_MAP_KEY_VECTOR_VALUE),
Doubles.asList(rawVector),
ARGUMENT));

return res.build();
}

static Object decodeValue(FirestoreRpcContext<?> rpcContext, Value v) {
Value.ValueTypeCase typeCase = v.getValueTypeCase();
switch (typeCase) {
Expand Down Expand Up @@ -220,18 +245,72 @@ static Object decodeValue(FirestoreRpcContext<?> rpcContext, Value v) {
}
return list;
case MAP_VALUE:
return decodeMap(rpcContext, v.getMapValue());
default:
throw FirestoreException.forInvalidArgument(
String.format("Unknown Value Type: %s", typeCase));
}
}

static Object decodeMap(FirestoreRpcContext<?> rpcContext, MapValue mapValue) {
MapRepresentation mapRepresentation = detectMapRepresentation(mapValue);
Map<String, Value> inputMap = mapValue.getFieldsMap();
switch (mapRepresentation) {
case UNKNOWN:
LOGGER.warning(
"Parsing unknown map type as generic map. This map type may be supported in a newer SDK version.");
case NONE:
Map<String, Object> outputMap = new HashMap<>();
Map<String, Value> inputMap = v.getMapValue().getFieldsMap();
for (Map.Entry<String, Value> entry : inputMap.entrySet()) {
outputMap.put(entry.getKey(), decodeValue(rpcContext, entry.getValue()));
}
return outputMap;
case VECTOR_VALUE:
double[] values =
inputMap.get(MapType.VECTOR_MAP_VECTORS_KEY).getArrayValue().getValuesList().stream()
.mapToDouble(val -> val.getDoubleValue())
.toArray();
return new VectorValue(values);
default:
throw FirestoreException.forInvalidArgument(
String.format("Unknown Value Type: %s", typeCase));
String.format("Unsupported MapRepresentation: %s", mapRepresentation));
}
}

/** Indicates the data type represented by a MapValue. */
enum MapRepresentation {
/** The MapValue represents an unknown data type. */
UNKNOWN,
/** The MapValue does not represent any special data type. */
NONE,
/** The MapValue represents a VectorValue. */
VECTOR_VALUE
}

static MapRepresentation detectMapRepresentation(MapValue mapValue) {
Map<String, Value> fields = mapValue.getFieldsMap();
if (!fields.containsKey(MapType.RESERVED_MAP_KEY)) {
return MapRepresentation.NONE;
}

Value typeValue = fields.get(MapType.RESERVED_MAP_KEY);
if (typeValue.getValueTypeCase() != Value.ValueTypeCase.STRING_VALUE) {
LOGGER.warning(
"Unable to parse __type__ field of map. Unsupported value type: "
+ typeValue.getValueTypeCase().toString());
return MapRepresentation.UNKNOWN;
}

String typeString = typeValue.getStringValue();

if (typeString.equals(MapType.RESERVED_MAP_KEY_VECTOR_VALUE)) {
return MapRepresentation.VECTOR_VALUE;
}

LOGGER.warning("Unsupported __type__ value for map: " + typeString);
return MapRepresentation.UNKNOWN;
}

static Object decodeGoogleProtobufValue(com.google.protobuf.Value v) {
switch (v.getKindCase()) {
case NULL_VALUE:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

import com.google.firestore.v1.MapValue;
import java.io.Serializable;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
* Represents a vector in Firestore documents. Create an instance with {@link FieldValue#vector}.
*/
public final class VectorValue implements Serializable {
private final double[] values;

VectorValue(@Nullable double[] values) {
if (values == null) this.values = new double[] {};
else this.values = values.clone();
}

/**
* Returns a representation of the vector as an array of doubles.
*
* @return A representation of the vector as an array of doubles
*/
@Nonnull
public double[] toArray() {
return this.values.clone();
}

/**
* Returns true if this VectorValue is equal to the provided object.
*
* @param obj The object to compare against.
* @return Whether this VectorValue is equal to the provided object.
*/
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
VectorValue otherArray = (VectorValue) obj;
return Arrays.equals(this.values, otherArray.values);
}

@Override
public int hashCode() {
return Arrays.hashCode(values);
}

MapValue toProto() {
return UserDataConverter.encodeVector(this.values);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,8 @@ public void extractFieldMaskFromMerge() throws Exception {
"second.objectValue.foo",
"second.timestampValue",
"second.trueValue",
"second.model.foo");
"second.model.foo",
"second.vectorValue");

CommitRequest expectedCommit = commit(set(nestedUpdate, updateMask));
assertCommitEquals(expectedCommit, commitCapture.getValue());
Expand Down
Loading

0 comments on commit 81bfa0d

Please sign in to comment.