Skip to content

Commit

Permalink
genai: add test and example for JSON schema (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
jba authored May 29, 2024
1 parent 665f6f7 commit de3b601
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
41 changes: 40 additions & 1 deletion genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ func TestLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
checkMatch(t, responseString(resp), "picture|image", "person", "computer|laptop")
checkMatch(t, responseString(resp), "person", "computer|laptop")
})

t.Run("JSON", func(t *testing.T) {
Expand All @@ -454,11 +454,50 @@ func TestLive(t *testing.T) {
}
got := responseString(res)
t.Logf("got %s", got)
// Accept any valid JSON
var a any
if err := json.Unmarshal([]byte(got), &a); err != nil {
t.Fatal(err)
}
})
t.Run("JSON schema", func(t *testing.T) {
model := client.GenerativeModel("gemini-1.5-pro-latest")
model.SetTemperature(0)
model.ResponseMIMEType = "application/json"
model.ResponseSchema = &Schema{
Type: TypeArray,
Items: &Schema{
Type: TypeObject,
Properties: map[string]*Schema{
"name": {
Type: TypeString,
Description: "The name of the color",
},
"RGB": {
Type: TypeString,
Description: "The RGB value of the color, in hex",
},
},
Required: []string{"name", "RGB"},
},
}
res, err := model.GenerateContent(ctx, Text("List the primary colors."))
if err != nil {
t.Fatal(err)
}
got := responseString(res)
t.Logf("got %s", got)
// Check that the format of the result matches the schema. The actual content
// doesn't matter here.

type color struct {
Name, RGB string
}
var v []color
if err := json.Unmarshal([]byte(got), &v); err != nil {
t.Fatal(err)
}
})
}

func TestJoinResponses(t *testing.T) {
Expand Down
34 changes: 33 additions & 1 deletion genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package genai_test

import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -141,6 +142,37 @@ func ExampleGenerativeModel_CountTokens() {
fmt.Println("Num tokens:", resp.TotalTokens)
}

// This example shows how to get a JSON response that conforms to a schema.
func ExampleGenerativeModel_JSONSchema() {
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
if err != nil {
log.Fatal(err)
}
defer client.Close()

model := client.GenerativeModel("gemini-1.5-pro-latest")
// Ask the model to respond with JSON.
model.ResponseMIMEType = "application/json"
// Specify the format of the JSON.
model.ResponseSchema = &genai.Schema{
Type: genai.TypeArray,
Items: &genai.Schema{Type: genai.TypeString},
}
res, err := model.GenerateContent(ctx, genai.Text("List the primary colors."))
if err != nil {
log.Fatal(err)
}
for _, part := range res.Candidates[0].Content.Parts {
if txt, ok := part.(genai.Text); ok {
var colors []string
if err := json.Unmarshal([]byte(txt), &colors); err != nil {
log.Fatal(err)
}
fmt.Println(colors)
}
}
}
func ExampleChatSession() {
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
Expand Down Expand Up @@ -399,7 +431,7 @@ func ExampleTool() {
printResponse(res)
}

func ExampleToolConifg() {
func ExampleToolConfig() {
// This example shows how to affect how the model uses the tools provided to it.
// By setting the ToolConfig, you can disable function calling.

Expand Down

0 comments on commit de3b601

Please sign in to comment.