add Client.ListModels
jba authored and eliben committed Dec 13, 2023
1 parent 177f6cf commit 5563623
Showing 5 changed files with 197 additions and 5 deletions.
9 changes: 7 additions & 2 deletions genai/client.go
@@ -33,7 +33,8 @@ import (

// A Client is a Google generative AI client.
type Client struct {
c *gl.GenerativeClient
c *gl.GenerativeClient
mc *gl.ModelClient
}

// NewClient creates a new Google generative AI client.
@@ -48,7 +49,11 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (*Client, error
if err != nil {
return nil, err
}
return &Client{c: c}, nil
mc, err := gl.NewModelRESTClient(ctx, opts...)
if err != nil {
return nil, err
}
return &Client{c: c, mc: mc}, nil
}

// Close closes the client.
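
Not shown in this diff is how a caller constructs the client. A minimal usage sketch follows; the module path github.com/google/generative-ai-go/genai and the use of option.WithAPIKey are assumptions for illustration, not part of this commit.

package main

import (
	"context"
	"log"

	"github.com/google/generative-ai-go/genai" // assumed module path
	"google.golang.org/api/option"
)

func main() {
	ctx := context.Background()
	// NewClient now creates both the GenerativeClient and the new ModelClient.
	client, err := genai.NewClient(ctx, option.WithAPIKey("YOUR_API_KEY")) // placeholder key
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()
	// ... use client.GenerativeModel, client.EmbeddingModel, or client.ListModels.
}
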
34 changes: 31 additions & 3 deletions genai/client_test.go
@@ -64,7 +64,7 @@ func TestLive(t *testing.T) {
t.Run("streaming", func(t *testing.T) {
iter := model.GenerateContentStream(ctx, Text("Are you hungry?"))
got := responsesString(t, iter)
checkMatch(t, got, `(don't|do\s+not) (have|possess) .*(a .* body|the ability)`)
checkMatch(t, got, `(don't|do\s+not) (have|possess) .*(a .* needs|body|the ability)`)
})

t.Run("chat", func(t *testing.T) {
@@ -231,7 +231,7 @@ func TestLive(t *testing.T) {
if g, w := funcall.Name, weatherTool.FunctionDeclarations[0].Name; g != w {
t.Errorf("FunctionCall.Name: got %q, want %q", g, w)
}
if g, c := funcall.Args["location"], "New York"; !strings.Contains(g.(string), c) {
if g, c := funcall.Args["location"], "New York"; g != nil && !strings.Contains(g.(string), c) {
t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, g, c)
}
res, err = session.SendMessage(ctx, FunctionResponse{
@@ -243,7 +243,7 @@ func TestLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
checkMatch(t, responseString(res), "(it's}|weather) .*cold")
checkMatch(t, responseString(res), "(it's|it is|weather) .*cold")
})
t.Run("embed", func(t *testing.T) {
em := client.EmbeddingModel("embedding-001")
@@ -263,6 +263,34 @@ func TestLive(t *testing.T) {
t.Errorf("bad result: %v\n", res)
}
})
t.Run("list-models", func(t *testing.T) {
iter := client.ListModels(ctx)
var got []*Model
for {
m, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatal(err)
}
got = append(got, m)
}

for _, name := range []string{"gemini-pro", "embedding-001"} {
has := false
fullName := "models/" + name
for _, m := range got {
if m.Name == fullName {
has = true
break
}
}
if !has {
t.Errorf("missing model %q", name)
}
}
})
}

func TestJoinResponses(t *testing.T) {
10 changes: 10 additions & 0 deletions genai/config.yaml
@@ -120,6 +120,16 @@ types:

ContentEmbedding:

  Model:
    fields:
      BaseModelId:
        name: BaseModeID
      Temperature:
        type: float32
      TopP:
        type: float32
      TopK:
        type: int32



99 changes: 99 additions & 0 deletions genai/generativelanguagepb_veneer.gen.go
@@ -650,6 +650,105 @@ func (v HarmProbability) String() string {
return fmt.Sprintf("HarmProbability(%d)", v)
}

// Model is information about a Generative Language Model.
type Model struct {
	// Required. The resource name of the `Model`.
	//
	// Format: `models/{model}` with a `{model}` naming convention of:
	//
	// * "{base_model_id}-{version}"
	//
	// Examples:
	//
	// * `models/chat-bison-001`
	Name string
	// Required. The name of the base model; pass this to the generation request.
	//
	// Examples:
	//
	// * `chat-bison`
	BaseModeID string
	// Required. The version number of the model.
	//
	// This represents the major version.
	Version string
	// The human-readable name of the model. E.g. "Chat Bison".
	//
	// The name can be up to 128 characters long and can consist of any UTF-8
	// characters.
	DisplayName string
	// A short description of the model.
	Description string
	// Maximum number of input tokens allowed for this model.
	InputTokenLimit int32
	// Maximum number of output tokens available for this model.
	OutputTokenLimit int32
	// The model's supported generation methods.
	//
	// The method names are defined as Pascal case
	// strings, such as `generateMessage`, which correspond to API methods.
	SupportedGenerationMethods []string
	// Controls the randomness of the output.
	//
	// Values can range over `[0.0,1.0]`, inclusive. A value closer to `1.0` will
	// produce responses that are more varied, while a value closer to `0.0` will
	// typically result in less surprising responses from the model.
	// This value specifies the default to be used by the backend while making the
	// call to the model.
	Temperature float32
	// For Nucleus sampling.
	//
	// Nucleus sampling considers the smallest set of tokens whose probability
	// sum is at least `top_p`.
	// This value specifies the default to be used by the backend while making the
	// call to the model.
	TopP float32
	// For Top-k sampling.
	//
	// Top-k sampling considers the set of `top_k` most probable tokens.
	// This value specifies the default to be used by the backend while making the
	// call to the model.
	TopK int32
}

func (v *Model) toProto() *pb.Model {
	if v == nil {
		return nil
	}
	return &pb.Model{
		Name:                       v.Name,
		BaseModelId:                v.BaseModeID,
		Version:                    v.Version,
		DisplayName:                v.DisplayName,
		Description:                v.Description,
		InputTokenLimit:            v.InputTokenLimit,
		OutputTokenLimit:           v.OutputTokenLimit,
		SupportedGenerationMethods: v.SupportedGenerationMethods,
		Temperature:                support.AddrOrNil(v.Temperature),
		TopP:                       support.AddrOrNil(v.TopP),
		TopK:                       support.AddrOrNil(v.TopK),
	}
}

func (Model) fromProto(p *pb.Model) *Model {
	if p == nil {
		return nil
	}
	return &Model{
		Name:                       p.Name,
		BaseModeID:                 p.BaseModelId,
		Version:                    p.Version,
		DisplayName:                p.DisplayName,
		Description:                p.Description,
		InputTokenLimit:            p.InputTokenLimit,
		OutputTokenLimit:           p.OutputTokenLimit,
		SupportedGenerationMethods: p.SupportedGenerationMethods,
		Temperature:                support.DerefOrZero(p.Temperature),
		TopP:                       support.DerefOrZero(p.TopP),
		TopK:                       support.DerefOrZero(p.TopK),
	}
}
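
The support.AddrOrNil and support.DerefOrZero helpers live in an internal package that is not part of this diff. A plausible sketch of their behavior, inferred from how the optional proto fields Temperature, TopP, and TopK are converted above; the generic signatures are an assumption:

// AddrOrNil returns a pointer to v, or nil when v is the zero value, so an
// unset veneer field maps to an unset optional field on the proto message.
func AddrOrNil[T comparable](v T) *T {
	var zero T
	if v == zero {
		return nil
	}
	return &v
}

// DerefOrZero returns *p, or the zero value of T when p is nil.
func DerefOrZero[T any](p *T) T {
	if p == nil {
		var zero T
		return zero
	}
	return *p
}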

// PromptFeedback contains a set of the feedback metadata the prompt specified in
// `GenerateContentRequest.content`.
type PromptFeedback struct {
50 changes: 50 additions & 0 deletions genai/list_models.go
@@ -0,0 +1,50 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package genai

import (
	"context"

	gl "cloud.google.com/go/ai/generativelanguage/apiv1beta"
	pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb"

	"google.golang.org/api/iterator"
)

// ListModels returns an iterator that enumerates the models available
// through the API.
func (c *Client) ListModels(ctx context.Context) *ModelIterator {
	return &ModelIterator{
		it: c.mc.ListModels(ctx, &pb.ListModelsRequest{}),
	}
}

// A ModelIterator iterates over Models.
type ModelIterator struct {
	it *gl.ModelIterator
}

// Next returns the next result. Its second return value is iterator.Done if there are no more
// results. Once Next returns Done, all subsequent calls will return Done.
func (it *ModelIterator) Next() (*Model, error) {
	m, err := it.it.Next()
	if err != nil {
		return nil, err
	}
	return (Model{}).fromProto(m), nil
}

// PageInfo supports pagination. See the google.golang.org/api/iterator package for details.
func (it *ModelIterator) PageInfo() *iterator.PageInfo {
	return it.it.PageInfo()
}
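
Taken together, a caller would iterate the new ListModels API roughly as in the sketch below, which assumes the client and ctx from the earlier example plus imports of fmt, log, and google.golang.org/api/iterator:

// Print the resource name and token limits of every available model.
it := client.ListModels(ctx)
for {
	m, err := it.Next()
	if err == iterator.Done {
		break
	}
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s (input limit %d, output limit %d)\n", m.Name, m.InputTokenLimit, m.OutputTokenLimit)
}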
