Skip to content

Commit

Permalink
genai: use new model names (#44)
Browse files Browse the repository at this point in the history
Update tests and examples to use the new preferred names for models.
  • Loading branch information
jba authored Feb 17, 2024
1 parent 03791f5 commit 223b9f6
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Guidelines](https://opensource.google/conduct/).
2. Run tests with `go test ./...`
3. You may need to run "live" tests that talk to a real endpoint; to do so, run
`go test -v ./genai/...` passing it your API key with the `-apikey` flag
and a model name flag like `-model gemini-pro`
and a model name flag like `-model gemini-1.0-pro`

### Code Reviews

Expand Down
2 changes: 1 addition & 1 deletion genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ type GenerativeModel struct {
}

// GenerativeModel creates a new instance of the named generative model.
// For instance, "gemini-pro" or "models/gemini-pro".
// For instance, "gemini-1.0-pro" or "models/gemini-1.0-pro".
//
// To access a tuned model named NAME, pass "tunedModels/NAME".
func (c *Client) GenerativeModel(name string) *GenerativeModel {
Expand Down
16 changes: 8 additions & 8 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestLive(t *testing.T) {
})

t.Run("image", func(t *testing.T) {
vmodel := client.GenerativeModel(*modelName + "-vision")
vmodel := client.GenerativeModel(*modelName + "-vision-latest")
vmodel.Temperature = Ptr[float32](0)

data, err := os.ReadFile(filepath.Join("testdata", imageFile))
Expand Down Expand Up @@ -262,7 +262,7 @@ func TestLive(t *testing.T) {
got = append(got, m)
}

for _, name := range []string{"gemini-pro", "embedding-001"} {
for _, name := range []string{"gemini-1.0-pro", "embedding-001"} {
has := false
fullName := "models/" + name
for _, m := range got {
Expand All @@ -277,21 +277,21 @@ func TestLive(t *testing.T) {
}
})
t.Run("get-model", func(t *testing.T) {
modelName := "gemini-pro"
got, err := client.GenerativeModel(modelName).Info(ctx)
modName := *modelName
got, err := client.GenerativeModel(modName).Info(ctx)
if err != nil {
t.Fatal(err)
}
if w := "models/" + modelName; got.Name != w {
if w := "models/" + modName; got.Name != w {
t.Errorf("got name %q, want %q", got.Name, w)
}

modelName = "embedding-001"
got, err = client.EmbeddingModel(modelName).Info(ctx)
modName = "embedding-001"
got, err = client.EmbeddingModel(modName).Info(ctx)
if err != nil {
t.Fatal(err)
}
if w := "models/" + modelName; got.Name != w {
if w := "models/" + modName; got.Name != w {
t.Errorf("got name %q, want %q", got.Name, w)
}
})
Expand Down
14 changes: 7 additions & 7 deletions genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func ExampleGenerativeModel_GenerateContent() {
}
defer client.Close()

model := client.GenerativeModel("gemini-pro")
model := client.GenerativeModel("gemini-1.0-pro")
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
log.Fatal(err)
Expand All @@ -54,7 +54,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
}
defer client.Close()

model := client.GenerativeModel("gemini-pro")
model := client.GenerativeModel("gemini-1.0-pro")
model.SetTemperature(0.9)
model.SetTopP(0.5)
model.SetTopK(20)
Expand All @@ -76,7 +76,7 @@ func ExampleGenerativeModel_GenerateContent_safetySetting() {
}
defer client.Close()

model := client.GenerativeModel("gemini-pro")
model := client.GenerativeModel("gemini-1.0-pro")
model.SafetySettings = []*genai.SafetySetting{
{
Category: genai.HarmCategoryDangerousContent,
Expand All @@ -102,7 +102,7 @@ func ExampleGenerativeModel_GenerateContentStream() {
}
defer client.Close()

model := client.GenerativeModel("gemini-pro")
model := client.GenerativeModel("gemini-1.0-pro")

iter := model.GenerateContentStream(ctx, genai.Text("Tell me a story about a lumberjack and his giant ox. Keep it very short."))
for {
Expand All @@ -125,7 +125,7 @@ func ExampleGenerativeModel_CountTokens() {
}
defer client.Close()

model := client.GenerativeModel("gemini-pro")
model := client.GenerativeModel("gemini-1.0-pro")

resp, err := model.CountTokens(ctx, genai.Text("What kind of fish is this?"))
if err != nil {
Expand All @@ -142,7 +142,7 @@ func ExampleChatSession() {
log.Fatal(err)
}
defer client.Close()
model := client.GenerativeModel("gemini-pro")
model := client.GenerativeModel("gemini-1.0-pro")
cs := model.StartChat()

send := func(msg string) *genai.GenerateContentResponse {
Expand Down Expand Up @@ -222,7 +222,7 @@ func ExampleGenerativeModel_GenerateContentStream_errors() {
log.Fatal(err)
}

model := client.GenerativeModel("gemini-pro")
model := client.GenerativeModel("gemini-1.0-pro")

iter := model.GenerateContentStream(ctx, genai.ImageData("foo", []byte("bar")))
res, err := iter.Next()
Expand Down

0 comments on commit 223b9f6

Please sign in to comment.