diff --git a/genai/client_test.go b/genai/client_test.go index 555666e..af7032c 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -441,7 +441,7 @@ func TestLive(t *testing.T) { if err != nil { t.Fatal(err) } - checkMatch(t, responseString(resp), "picture|image", "person", "computer|laptop") + checkMatch(t, responseString(resp), "person", "computer|laptop") }) t.Run("JSON", func(t *testing.T) { @@ -454,11 +454,50 @@ func TestLive(t *testing.T) { } got := responseString(res) t.Logf("got %s", got) + // Accept any valid JSON var a any if err := json.Unmarshal([]byte(got), &a); err != nil { t.Fatal(err) } }) + t.Run("JSON schema", func(t *testing.T) { + model := client.GenerativeModel("gemini-1.5-pro-latest") + model.SetTemperature(0) + model.ResponseMIMEType = "application/json" + model.ResponseSchema = &Schema{ + Type: TypeArray, + Items: &Schema{ + Type: TypeObject, + Properties: map[string]*Schema{ + "name": { + Type: TypeString, + Description: "The name of the color", + }, + "RGB": { + Type: TypeString, + Description: "The RGB value of the color, in hex", + }, + }, + Required: []string{"name", "RGB"}, + }, + } + res, err := model.GenerateContent(ctx, Text("List the primary colors.")) + if err != nil { + t.Fatal(err) + } + got := responseString(res) + t.Logf("got %s", got) + // Check that the format of the result matches the schema. The actual content + // doesn't matter here. + + type color struct { + Name, RGB string + } + var v []color + if err := json.Unmarshal([]byte(got), &v); err != nil { + t.Fatal(err) + } + }) } func TestJoinResponses(t *testing.T) { diff --git a/genai/example_test.go b/genai/example_test.go index 28f8551..99c90fc 100644 --- a/genai/example_test.go +++ b/genai/example_test.go @@ -16,6 +16,7 @@ package genai_test import ( "context" + "encoding/json" "errors" "fmt" "log" @@ -141,6 +142,37 @@ func ExampleGenerativeModel_CountTokens() { fmt.Println("Num tokens:", resp.TotalTokens) } +// This example shows how to get a JSON response that conforms to a schema. +func ExampleGenerativeModel_JSONSchema() { + ctx := context.Background() + client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY"))) + if err != nil { + log.Fatal(err) + } + defer client.Close() + + model := client.GenerativeModel("gemini-1.5-pro-latest") + // Ask the model to respond with JSON. + model.ResponseMIMEType = "application/json" + // Specify the format of the JSON. + model.ResponseSchema = &genai.Schema{ + Type: genai.TypeArray, + Items: &genai.Schema{Type: genai.TypeString}, + } + res, err := model.GenerateContent(ctx, genai.Text("List the primary colors.")) + if err != nil { + log.Fatal(err) + } + for _, part := range res.Candidates[0].Content.Parts { + if txt, ok := part.(genai.Text); ok { + var colors []string + if err := json.Unmarshal([]byte(txt), &colors); err != nil { + log.Fatal(err) + } + fmt.Println(colors) + } + } +} func ExampleChatSession() { ctx := context.Background() client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY"))) @@ -399,7 +431,7 @@ func ExampleTool() { printResponse(res) } -func ExampleToolConifg() { +func ExampleToolConfig() { // This example shows how to affect how the model uses the tools provided to it. // By setting the ToolConfig, you can disable function calling.