Skip to content

Commit

Permalink
genai: pass cached content in request (#149)
Browse files Browse the repository at this point in the history
We weren't passing the cached content name in generation or CountTokens
requests. Fix that.

Also support the cached-content token count in CountTokens responses
(the service does not implement it yet, so the test for it is skipped).
  • Loading branch information
jba authored Jun 26, 2024
1 parent 147c8a5 commit 205ce44
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
40 changes: 30 additions & 10 deletions genai/caching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func testCaching(t *testing.T, client *Client) {
t.Fatal("want error, got nil")
}
})
t.Run("generation", func(t *testing.T) {
t.Run("use", func(t *testing.T) {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &CachedContent{
Model: model,
Expand All @@ -166,16 +166,36 @@ func testCaching(t *testing.T, client *Client) {
t.Fatal(err)
}
defer client.DeleteCachedContent(ctx, cc.Name)
tokenCount := cc.UsageMetadata.TotalTokenCount
m := client.GenerativeModelFromCachedContent(cc)
res, err := m.GenerateContent(ctx, Text("Who was the first US president?"))
if err != nil {
t.Fatal(err)
}
got := responseString(res)
const want = "Washington"
if !strings.Contains(got, want) {
t.Errorf("got %q, want string containing %q", got, want)
}
t.Run("generation", func(t *testing.T) {
res, err := m.GenerateContent(ctx, Text("Who was the first US president?"))
if err != nil {
t.Fatal(err)
}
got := responseString(res)
const want = "Washington"
if !strings.Contains(got, want) {
t.Errorf("got %q, want string containing %q", got, want)
}
if g, w := res.UsageMetadata.CachedContentTokenCount, tokenCount; g != w {
t.Errorf("CachedContentTokenCount: got %d, want %d", g, w)
}
})
t.Run("count", func(t *testing.T) {
t.Skip("not yet implemented")
gotRes, err := m.CountTokens(ctx, Text("Who Was the first US president?"))
if err != nil {
t.Fatal(err)
}
wantRes := &CountTokensResponse{
TotalTokens: 8,
CachedContentTokenCount: tokenCount,
}
if !cmp.Equal(gotRes, wantRes) {
t.Errorf("got %+v, want %+v", gotRes, wantRes)
}
})
})
}

Expand Down
5 changes: 5 additions & 0 deletions genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ func (m *GenerativeModel) generateContent(ctx context.Context, req *pb.GenerateC

func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.GenerateContentRequest, error) {
return pvCatchPanic(func() *pb.GenerateContentRequest {
var cc *string
if m.CachedContentName != "" {
cc = &m.CachedContentName
}
return &pb.GenerateContentRequest{
Model: m.fullName,
Contents: transformSlice(contents, (*Content).toProto),
Expand All @@ -205,6 +209,7 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.G
ToolConfig: m.ToolConfig.toProto(),
GenerationConfig: m.GenerationConfig.toProto(),
SystemInstruction: m.SystemInstruction.toProto(),
CachedContent: cc,
}
})
}
Expand Down

0 comments on commit 205ce44

Please sign in to comment.