Skip to content

Commit

Permalink
genai: support the GetModel RPC (#40)
Browse files Browse the repository at this point in the history
Add GenerativeModel.Info and EmbeddedModel.Info, which both call the
GetModel RPC.
  • Loading branch information
jba authored Feb 6, 2024
1 parent 9779dcc commit fffe17c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
14 changes: 14 additions & 0 deletions genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,20 @@ func (m *GenerativeModel) newCountTokensRequest(contents ...*Content) *pb.CountT
}
}

// Info returns information about the model.
func (m *GenerativeModel) Info(ctx context.Context) (*ModelInfo, error) {
return m.c.modelInfo(ctx, m.fullName)
}

func (c *Client) modelInfo(ctx context.Context, fullName string) (*ModelInfo, error) {
req := &pb.GetModelRequest{Name: fullName}
res, err := c.mc.GetModel(ctx, req)
if err != nil {
return nil, err
}
return (ModelInfo{}).fromProto(res), nil
}

// A BlockedError indicates that the model's response was blocked.
// There can be two underlying causes: the prompt or a candidate response.
type BlockedError struct {
Expand Down
19 changes: 19 additions & 0 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,25 @@ func TestLive(t *testing.T) {
}
}
})
t.Run("get-model", func(t *testing.T) {
modelName := "gemini-pro"
got, err := client.GenerativeModel(modelName).Info(ctx)
if err != nil {
t.Fatal(err)
}
if w := "models/" + modelName; got.Name != w {
t.Errorf("got name %q, want %q", got.Name, w)
}

modelName = "embedding-001"
got, err = client.EmbeddingModel(modelName).Info(ctx)
if err != nil {
t.Fatal(err)
}
if w := "models/" + modelName; got.Name != w {
t.Errorf("got name %q, want %q", got.Name, w)
}
})
}

func TestJoinResponses(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions genai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,8 @@ func (m *EmbeddingModel) BatchEmbedContents(ctx context.Context, b *EmbeddingBat
}
return (BatchEmbedContentsResponse{}).fromProto(res), nil
}

// Info returns information about the model.
func (m *EmbeddingModel) Info(ctx context.Context) (*ModelInfo, error) {
return m.c.modelInfo(ctx, m.fullName)
}

0 comments on commit fffe17c

Please sign in to comment.