Skip to content

Commit

Permalink
feat(firestore): Adding distance threshold and result field (#10802)
Browse files Browse the repository at this point in the history
* feat(firestore): Adding distance threshold and result field

* refactor(firestore): Renaming method names

* refactor(firestore): Move threshold and result field to options. Rename FindNearestOptions

* refactor(firestore): Rename to FindNearestOptions

* refactor(firestore): Refactoring code
  • Loading branch information
bhshkh authored Sep 11, 2024
1 parent 839f30e commit e9a551e
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 66 deletions.
5 changes: 4 additions & 1 deletion firestore/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,10 @@ func ExampleQuery_FindNearest() {

//
q := client.Collection("descriptions").
FindNearest("Embedding", []float32{1, 2, 3}, 5, firestore.DistanceMeasureDotProduct, nil)
FindNearest("Embedding", []float32{1, 2, 3}, 5, firestore.DistanceMeasureDotProduct, &firestore.FindNearestOptions{
DistanceThreshold: firestore.Ptr(20.0),
DistanceResultField: "vector_distance",
})
iter1 := q.Documents(ctx)
_ = iter1 // TODO: Use iter1.
}
Expand Down
5 changes: 3 additions & 2 deletions firestore/fieldpath.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"cloud.google.com/go/internal/fields"
)

const invalidRunes = "~*/[]"

// A FieldPath is a non-empty sequence of non-empty fields that reference a value.
//
// A FieldPath value should only be necessary if one of the field names contains
Expand All @@ -54,9 +56,8 @@ type FieldPath []string
// including attempts to quote field path compontents. So "a.`b.c`.d" is parsed into
// four parts, "a", "`b", "c`" and "d".
func parseDotSeparatedString(s string) (FieldPath, error) {
const invalidRunes = "~*/[]"
if strings.ContainsAny(s, invalidRunes) {
return nil, fmt.Errorf("firestore: %q contains an invalid rune (one of %s)", s, invalidRunes)
return nil, errInvalidRunesField(s)
}
fp := FieldPath(strings.Split(s, "."))
if err := fp.validate(); err != nil {
Expand Down
97 changes: 69 additions & 28 deletions firestore/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3218,6 +3218,7 @@ func TestIntegration_FindNearest(t *testing.T) {
cancel()
})
queryField := "EmbeddedField64"
resultField := "vector_distance"
indexNames := createVectorIndexes(adminCtx, t, wantDBPath, []vectorIndex{
{
fieldPath: queryField,
Expand All @@ -3229,34 +3230,46 @@ func TestIntegration_FindNearest(t *testing.T) {
})

type coffeeBean struct {
ID string
ID int
EmbeddedField64 Vector64
EmbeddedField32 Vector32
Float32s []float32 // When querying, saving and retrieving, this should be retrieved as []float32 and not Vector32
}

beans := []coffeeBean{
{
ID: "Robusta",
{ // Euclidean Distance from {1, 2, 3} = 0
ID: 0,
EmbeddedField64: []float64{1, 2, 3},
EmbeddedField32: []float32{1, 2, 3},
Float32s: []float32{1, 2, 3},
},
{
ID: "Excelsa",
{ // Euclidean Distance from {1, 2, 3} = 5.19
ID: 1,
EmbeddedField64: []float64{4, 5, 6},
EmbeddedField32: []float32{4, 5, 6},
Float32s: []float32{4, 5, 6},
},
{ // Euclidean Distance from {1, 2, 3} = 10.39
ID: 2,
EmbeddedField64: []float64{7, 8, 9},
EmbeddedField32: []float32{7, 8, 9},
Float32s: []float32{7, 8, 9},
},
{ // Euclidean Distance from {1, 2, 3} = 15.58
ID: 3,
EmbeddedField64: []float64{10, 11, 12},
EmbeddedField32: []float32{10, 11, 12},
Float32s: []float32{10, 11, 12},
},
{
ID: "Arabica",
// Euclidean Distance from {1, 2, 3} = 370.42
ID: 4,
EmbeddedField64: []float64{100, 200, 300}, // too far from query vector. not within findNearest limit
EmbeddedField32: []float32{100, 200, 300},
Float32s: []float32{100, 200, 300},
},

{
ID: "Liberica",
ID: 5,
EmbeddedField64: []float64{1, 2}, // Not enough dimensions as compared to query vector.
EmbeddedField32: []float32{1, 2},
Float32s: []float32{1, 2},
Expand All @@ -3277,27 +3290,55 @@ func TestIntegration_FindNearest(t *testing.T) {
h.mustCreate(doc, beans[i])
}

// Query documents with a vector field
vectorQuery := collRef.FindNearest(queryField, []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil)

iter := vectorQuery.Documents(ctx)
gotDocs, err := iter.GetAll()
if err != nil {
t.Fatalf("GetAll: %+v", err)
}
for _, tc := range []struct {
desc string
vq VectorQuery
wantBeans []coffeeBean
wantResField string
}{
{
desc: "FindNearest without threshold without resultField",
vq: collRef.FindNearest(queryField, []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil),
wantBeans: beans[:2],
},
{
desc: "FindNearest threshold and resultField",
vq: collRef.FindNearest(queryField, []float64{1, 2, 3}, 3, DistanceMeasureEuclidean, &FindNearestOptions{
DistanceThreshold: Ptr(20.0),
DistanceResultField: resultField,
}),
wantBeans: beans[:3],
wantResField: resultField,
},
} {
t.Run(tc.desc, func(t *testing.T) {
iter := tc.vq.Documents(ctx)
gotDocs, err := iter.GetAll()
if err != nil {
t.Fatalf("GetAll: %+v", err)
}

if len(gotDocs) != 2 {
t.Fatalf("Expected 2 results, got %d", len(gotDocs))
}
if len(gotDocs) != len(tc.wantBeans) {
t.Fatalf("Expected %v results, got %d", len(tc.wantBeans), len(gotDocs))
}

for i, doc := range gotDocs {
gotBean := coffeeBean{}
err := doc.DataTo(&gotBean)
if err != nil {
t.Errorf("#%v: DataTo: %+v", doc.Ref.ID, err)
}
if beans[i].ID != gotBean.ID {
t.Errorf("#%v: want: %v, got: %v", i, beans[i].ID, gotBean.ID)
}
for i, doc := range gotDocs {
var gotBean coffeeBean
if len(tc.wantResField) != 0 {
_, ok := doc.Data()[tc.wantResField]
if !ok {
t.Errorf("Expected %v field to exist in %v", tc.wantResField, doc.Data())
}
}
err := doc.DataTo(&gotBean)
if err != nil {
t.Errorf("#%v: DataTo: %+v", doc.Ref.ID, err)
continue
}
if tc.wantBeans[i].ID != gotBean.ID {
t.Errorf("#%v: want: %v, got: %v", i, beans[i].ID, gotBean.ID)
}
}
})
}
}
37 changes: 34 additions & 3 deletions firestore/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@ import (
)

var (
errMetricsBeforeEnd = errors.New("firestore: ExplainMetrics are available only after the iterator reaches the end")
errMetricsBeforeEnd = errors.New("firestore: ExplainMetrics are available only after the iterator reaches the end")
errInvalidVector = errors.New("firestore: queryVector must be Vector32 or Vector64")
errMalformedVectorQuery = errors.New("firestore: Malformed VectorQuery. Use FindNearest or FindNearestPath to create VectorQuery")
)

func errInvalidRunesField(field string) error {
return fmt.Errorf("firestore: %q contains an invalid rune (one of %s)", field, invalidRunes)
}

// Query represents a Firestore query.
//
// Query values are immutable. Each Query method creates
Expand Down Expand Up @@ -517,9 +523,27 @@ const (
DistanceMeasureDotProduct DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_DOT_PRODUCT)
)

// Ptr returns a pointer to its argument.
// It can be used to initialize pointer fields:
//
// findNearestOptions.DistanceThreshold = firestore.Ptr[float64](0.1)
func Ptr[T any](t T) *T { return &t }

// FindNearestOptions are options for a FindNearest vector query.
// At present, there are no options.
type FindNearestOptions struct {
// DistanceThreshold specifies a threshold for which no less similar documents
// will be returned. The behavior of the specified [DistanceMeasure] will
// affect the meaning of the distance threshold. Since [DistanceMeasureDotProduct]
// distances increase when the vectors are more similar, the comparison is inverted.
// For [DistanceMeasureEuclidean], [DistanceMeasureCosine]: WHERE distance <= distanceThreshold
// For [DistanceMeasureDotProduct]: WHERE distance >= distance_threshold
DistanceThreshold *float64

// DistanceResultField specifies name of the document field to output the result of
// the vector distance calculation.
// If the field already exists in the document, its value get overwritten with the distance calculation.
// Otherwise, a new field gets added to the document.
DistanceResultField string
}

// VectorQuery represents a query that uses [Query.FindNearest] or [Query.FindNearestPath].
Expand Down Expand Up @@ -582,7 +606,7 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit
case []float64:
fnvq = vectorToProtoValue(v)
default:
vq.q.err = errors.New("firestore: queryVector must be Vector32 or Vector64")
vq.q.err = errInvalidVector
return vq
}

Expand All @@ -592,6 +616,13 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit
Limit: &wrapperspb.Int32Value{Value: trunc32(limit)},
DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(measure),
}

if options != nil {
if options.DistanceThreshold != nil {
vq.q.findNearest.DistanceThreshold = &wrapperspb.DoubleValue{Value: *options.DistanceThreshold}
}
vq.q.findNearest.DistanceResultField = *&options.DistanceResultField
}
return vq
}

Expand Down
Loading

0 comments on commit e9a551e

Please sign in to comment.