Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

genai: add test and example for JSON schema #123

Merged
merged 3 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,50 @@ func TestLive(t *testing.T) {
}
got := responseString(res)
t.Logf("got %s", got)
// Accept any valid JSON
var a any
if err := json.Unmarshal([]byte(got), &a); err != nil {
t.Fatal(err)
}
})
t.Run("JSON schema", func(t *testing.T) {
model := client.GenerativeModel("gemini-1.5-pro-latest")
model.SetTemperature(0)
model.ResponseMIMEType = "application/json"
model.ResponseSchema = &Schema{
Type: TypeArray,
Items: &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"name": {
Type: TypeString,
Description: "The name of the color",
},
"RGB": {
Type: TypeString,
Description: "The RGB value of the color, in hex",
},
},
Required: []string{"name", "RGB"},
},
}
res, err := model.GenerateContent(ctx, Text("List the primary colors."))
if err != nil {
t.Fatal(err)
}
got := responseString(res)
t.Logf("got %s", got)
// Check that the format of the result matches the schema. The actual content
// doesn't matter here.

type color struct {
Name, RGB string
}
var v []color
if err := json.Unmarshal([]byte(got), &v); err != nil {
t.Fatal(err)
}
})
}

func TestJoinResponses(t *testing.T) {
Expand Down
32 changes: 31 additions & 1 deletion genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"os"
"slices"

Check failure on line 26 in genai/example_test.go

View workflow job for this annotation

GitHub Actions / build

"slices" imported and not used

"github.com/google/generative-ai-go/genai"
"google.golang.org/api/googleapi"
Expand Down Expand Up @@ -141,6 +143,34 @@
fmt.Println("Num tokens:", resp.TotalTokens)
}

func ExampleGenerativeModel_JSONSchema() {
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
if err != nil {
log.Fatal(err)
}
defer client.Close()

model := client.GenerativeModel("gemini-1.5-pro")
jba marked this conversation as resolved.
Show resolved Hide resolved
model.ResponseMIMEType = "application/json"
model.ResponseSchema = &genai.Schema{
jba marked this conversation as resolved.
Show resolved Hide resolved
Type: genai.TypeArray,
Items: &genai.Schema{Type: genai.TypeString},
}
res, err := model.GenerateContent(ctx, genai.Text("List the primary colors."))
if err != nil {
log.Fatal(err)
}
for _, part := range res.Candidates[0].Content.Parts {
if txt, ok := part.(genai.Text); ok {
var colors []string
if err := json.Unmarshal([]byte(txt), &colors); err != nil {
log.Fatal(err)
}
fmt.Println(colors)
}
}
}
func ExampleChatSession() {
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
Expand Down Expand Up @@ -399,7 +429,7 @@
printResponse(res)
}

func ExampleToolConifg() {
func ExampleToolConfig() {
// This example shows how to affect how the model uses the tools provided to it.
// By setting the ToolConfig, you can disable function calling.

Expand Down
Loading