Skip to content

Commit

Permalink
feat: Batch Text Embeddings Sample
Browse files Browse the repository at this point in the history
  • Loading branch information
BigBlackWolf committed Oct 2, 2024
1 parent 6e7b24c commit d816a79
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
79 changes: 79 additions & 0 deletions aiplatform/snippets/embedding_batch_predict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

// [START generativeaionvertexai_sdk_embedding_batch]
import (
"context"
"fmt"
"io"

aiplatform "cloud.google.com/go/aiplatform/apiv1"
aiplatformpb "cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"google.golang.org/api/option"
)

// embedBatchPredict creates a Vertex AI batch prediction job that runs the
// pretrained textembedding-gecko publisher model over JSONL input files stored
// in Cloud Storage and writes JSONL predictions under outputURI.
//
// Parameters:
//   - w: receives the display name of the created job.
//   - projectID, location: identify the parent resource
//     ("projects/PROJECT/locations/LOCATION"); location also selects the
//     regional API endpoint.
//   - name: display name assigned to the job.
//   - outputURI: output location prefix, e.g. gs://BUCKET_NAME/DIRECTORY/
//   - inputURIs: Cloud Storage URIs of the input files, e.g.
//     []string{"gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl"}
//
// It returns an error if the client cannot be created or if the job creation
// request fails.
func embedBatchPredict(w io.Writer, projectID, location, name, outputURI string, inputURIs []string) error {
	// inputURIs := []string{"gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl"}
	// outputURI: existing template path. Following formats are allowed:
	// - gs://BUCKET_NAME/DIRECTORY/
	// - bq://project_name.llm_dataset
	// NOTE(review): outputURI is sent below as a Cloud Storage output prefix,
	// so a bq:// path presumably needs a BigQuery destination instead — confirm.

	ctx := context.Background()
	apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// Pretrained model
	model := "publishers/google/models/textembedding-gecko"

	client, err := aiplatform.NewJobClient(ctx, option.WithEndpoint(apiEndpoint))
	if err != nil {
		return fmt.Errorf("creating job client for %q: %w", apiEndpoint, err)
	}
	// Release the client's underlying connection when the function returns.
	defer client.Close()

	req := &aiplatformpb.CreateBatchPredictionJobRequest{
		Parent: fmt.Sprintf("projects/%s/locations/%s", projectID, location),
		BatchPredictionJob: &aiplatformpb.BatchPredictionJob{
			DisplayName: name,
			Model:       model,
			InputConfig: &aiplatformpb.BatchPredictionJob_InputConfig{
				Source: &aiplatformpb.BatchPredictionJob_InputConfig_GcsSource{
					GcsSource: &aiplatformpb.GcsSource{
						Uris: inputURIs,
					},
				},
				// List of supported formats: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
				InstancesFormat: "jsonl",
			},
			OutputConfig: &aiplatformpb.BatchPredictionJob_OutputConfig{
				Destination: &aiplatformpb.BatchPredictionJob_OutputConfig_GcsDestination{
					GcsDestination: &aiplatformpb.GcsDestination{
						OutputUriPrefix: outputURI,
					},
				},
				// List of supported formats: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
				PredictionsFormat: "jsonl",
			},
		},
	}

	job, err := client.CreateBatchPredictionJob(ctx, req)
	if err != nil {
		return fmt.Errorf("creating batch prediction job %q: %w", name, err)
	}
	// Only the display name is reported; callers compare it with name.
	fmt.Fprint(w, job.GetDisplayName())

	return nil
}

// [END generativeaionvertexai_sdk_embedding_batch]
50 changes: 50 additions & 0 deletions aiplatform/snippets/embedding_batch_predict_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

import (
"bytes"
"context"
"fmt"
"math/rand"
"testing"
"time"

"github.com/GoogleCloudPlatform/golang-samples/internal/testutil"
)

// TestBatchPredict is a system test that calls embedBatchPredict against the
// configured test project and checks that the display name written to the
// output buffer equals the name that was requested.
func TestBatchPredict(t *testing.T) {
	tc := testutil.SystemTest(t)
	var buf bytes.Buffer
	// A time-seeded random suffix keeps job display names distinct across runs.
	r := rand.New(rand.NewSource(time.Now().UnixNano()))

	ctx := context.Background()
	bucketName := testutil.TestBucket(ctx, t, tc.ProjectID, "golang-samples-batch")
	location := "us-central1"
	outputURI := fmt.Sprintf("gs://%s/", bucketName)
	inputURIs := []string{"gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl"}
	name := fmt.Sprintf("test-job-go-batch-%v-%v", time.Now().Format("2006-01-02"), r.Int())

	// Stop here on failure: with no job created the buffer is empty, and the
	// name comparison below would only add a second, misleading failure.
	if err := embedBatchPredict(&buf, tc.ProjectID, location, name, outputURI, inputURIs); err != nil {
		t.Fatalf("embedBatchPredict: %v", err)
	}

	if output := buf.String(); output != name {
		t.Errorf("job name doesn't match. Got: %s, want: %s", output, name)
	}
}

0 comments on commit d816a79

Please sign in to comment.