Skip to content

Commit

Permalink
feat(firestore): Distance result field and distance threshold in vect…
Browse files Browse the repository at this point in the history
…or search (#4362)

* feat(firestore): Vector search

* test(firestore): Clean up test resources

* refactor(firestore): Refactoring tests

* feat(firestore): Distance result field and distance threshold in vector search

* feat(firestore): Updating branch

* feat(firestore): Add link to documentation

---------

Co-authored-by: Eric Schmidt <[email protected]>
  • Loading branch information
bhshkh and telpirion authored Sep 17, 2024
1 parent c4ce89c commit 47eea43
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 0 deletions.
1 change: 1 addition & 0 deletions firestore/vector_search_basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func vectorSearchBasic(w io.Writer, projectID string) error {
collection := client.Collection("coffee-beans")

// Requires a vector index
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
vectorQuery := collection.FindNearest("embedding_field",
[]float32{3.0, 1.0, 2.0},
5,
Expand Down
59 changes: 59 additions & 0 deletions firestore/vector_search_distance_threshold.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

// [START firestore_vector_search_basic]
import (
"context"
"fmt"
"io"

"cloud.google.com/go/firestore"
)

func vectorSearchDistanceThreshold(w io.Writer, projectID string) error {
ctx := context.Background()

client, err := firestore.NewClient(ctx, projectID)
if err != nil {
return fmt.Errorf("firestore.NewClient: %w", err)
}
defer client.Close()

collection := client.Collection("coffee-beans")

// Requires a vector index
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
vectorQuery := collection.FindNearest("embedding_field",
[]float32{3.0, 1.0, 2.0},
10,
firestore.DistanceMeasureEuclidean,
&firestore.FindNearestOptions{
DistanceThreshold: firestore.Ptr[float64](4.5),
})

docs, err := vectorQuery.Documents(ctx).GetAll()
if err != nil {
fmt.Fprintf(w, "failed to get vector query results: %v", err)
return err
}

for _, doc := range docs {
fmt.Fprintln(w, doc.Data()["name"])
}
return nil
}

// [END firestore_vector_search_basic]
42 changes: 42 additions & 0 deletions firestore/vector_search_distance_threshold_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

import (
"bytes"
"os"
"strings"
"testing"
)

func TestVectorSearchDistanceThreshold(t *testing.T) {
projectID := os.Getenv("GOLANG_SAMPLES_FIRESTORE_PROJECT")
if projectID == "" {
t.Skip("Skipping firestore test. Set GOLANG_SAMPLES_FIRESTORE_PROJECT.")
}

buf := new(bytes.Buffer)
if err := vectorSearchDistanceThreshold(buf, projectID); err != nil {
t.Errorf("vectorSearchDistanceThreshold: %v", err)
}

// Compare console outputs
got := buf.String()
want := "Sleepy coffee beans\n" +
"Kahawa coffee beans\n"
if !strings.Contains(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
59 changes: 59 additions & 0 deletions firestore/vector_search_result_field.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

// [START firestore_vector_search_distance_result_field]
import (
"context"
"fmt"
"io"

"cloud.google.com/go/firestore"
)

func vectorSearchDistanceResultField(w io.Writer, projectID string) error {
ctx := context.Background()

client, err := firestore.NewClient(ctx, projectID)
if err != nil {
return fmt.Errorf("firestore.NewClient: %w", err)
}
defer client.Close()

collection := client.Collection("coffee-beans")

// Requires a vector index
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
vectorQuery := collection.FindNearest("embedding_field",
[]float32{3.0, 1.0, 2.0},
10,
firestore.DistanceMeasureEuclidean,
&firestore.FindNearestOptions{
DistanceResultField: "vector_distance",
})

docs, err := vectorQuery.Documents(ctx).GetAll()
if err != nil {
fmt.Fprintf(w, "failed to get vector query results: %v", err)
return err
}

for _, doc := range docs {
fmt.Fprintf(w, "%v, Distance: %v\n", doc.Data()["name"], doc.Data()["vector_distance"])
}
return nil
}

// [END firestore_vector_search_distance_result_field]
60 changes: 60 additions & 0 deletions firestore/vector_search_result_field_masked.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

// [START firestore_vector_search_distance_result_field]
import (
"context"
"fmt"
"io"

"cloud.google.com/go/firestore"
)

func vectorSearchDistanceResultFieldMasked(w io.Writer, projectID string) error {
ctx := context.Background()

client, err := firestore.NewClient(ctx, projectID)
if err != nil {
return fmt.Errorf("firestore.NewClient: %w", err)
}
defer client.Close()

collection := client.Collection("coffee-beans")

// Requires a vector index
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
vectorQuery := collection.Select("color", "vector_distance").
FindNearest("embedding_field",
[]float32{3.0, 1.0, 2.0},
10,
firestore.DistanceMeasureEuclidean,
&firestore.FindNearestOptions{
DistanceResultField: "vector_distance",
})

docs, err := vectorQuery.Documents(ctx).GetAll()
if err != nil {
fmt.Fprintf(w, "failed to get vector query results: %v", err)
return err
}

for _, doc := range docs {
fmt.Fprintf(w, "%v, Distance: %v\n", doc.Data()["color"], doc.Data()["vector_distance"])
}
return nil
}

// [END firestore_vector_search_distance_result_field]
43 changes: 43 additions & 0 deletions firestore/vector_search_result_field_masked_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

import (
"bytes"
"os"
"strings"
"testing"
)

func TestVectorSearchDistanceResultFieldMasked(t *testing.T) {
projectID := os.Getenv("GOLANG_SAMPLES_FIRESTORE_PROJECT")
if projectID == "" {
t.Skip("Skipping firestore test. Set GOLANG_SAMPLES_FIRESTORE_PROJECT.")
}

buf := new(bytes.Buffer)
if err := vectorSearchDistanceResultFieldMasked(buf, projectID); err != nil {
t.Errorf("vectorSearchDistanceResultFieldMasked: %v", err)
}

// Compare console outputs
got := buf.String()
want := "red, Distance: 0\n" +
"red, Distance: 2.449489742783178\n" +
"brown, Distance: 5.744562646538029\n"
if !strings.Contains(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
43 changes: 43 additions & 0 deletions firestore/vector_search_result_field_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firestore

import (
"bytes"
"os"
"strings"
"testing"
)

func TestVectorSearchDistanceResultField(t *testing.T) {
projectID := os.Getenv("GOLANG_SAMPLES_FIRESTORE_PROJECT")
if projectID == "" {
t.Skip("Skipping firestore test. Set GOLANG_SAMPLES_FIRESTORE_PROJECT.")
}

buf := new(bytes.Buffer)
if err := vectorSearchDistanceResultField(buf, projectID); err != nil {
t.Errorf("vectorSearchDistanceResultField: %v", err)
}

// Compare console outputs
got := buf.String()
want := "Sleepy coffee beans, Distance: 0\n" +
"Kahawa coffee beans, Distance: 2.449489742783178\n" +
"Owl coffee beans, Distance: 5.744562646538029\n"
if !strings.Contains(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}

0 comments on commit 47eea43

Please sign in to comment.