Skip to content

Commit

Permalink
genai: turn conversion panics to errors (#143)
Browse files Browse the repository at this point in the history
Return an error if a conversion fails, instead of panicking.

Currently the only possible such failure is when a FunctionResponse
contains a value that can't be converted to a Struct proto.
  • Loading branch information
jba authored Jun 25, 2024
1 parent 761f4d9 commit 5220ad2
Show file tree
Hide file tree
Showing 5 changed files with 3,341 additions and 3,280 deletions.
10 changes: 8 additions & 2 deletions genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ func (m *GenerativeModel) StartChat() *ChatSession {
func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
// Call the underlying client with the entire history plus the argument Content.
cs.History = append(cs.History, newUserContent(parts))
req := cs.m.newGenerateContentRequest(cs.History...)
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return nil, err
}
req.GenerationConfig.CandidateCount = Ptr[int32](1)
resp, err := cs.m.generateContent(ctx, req)
if err != nil {
Expand All @@ -46,7 +49,10 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat
// SendMessageStream is like SendMessage, but with a streaming request.
func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
cs.History = append(cs.History, newUserContent(parts))
req := cs.m.newGenerateContentRequest(cs.History...)
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return &GenerateContentResponseIterator{err: err}
}
req.GenerationConfig.CandidateCount = Ptr[int32](1)
streamClient, err := cs.m.c.c.StreamGenerateContent(ctx, req)
return &GenerateContentResponseIterator{
Expand Down
68 changes: 45 additions & 23 deletions genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ func fullModelName(name string) string {
// GenerateContent produces a single request and response.
func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
content := newUserContent(parts)
req := m.newGenerateContentRequest(content)
req, err := m.newGenerateContentRequest(content)
if err != nil {
return nil, err
}
res, err := m.c.c.GenerateContent(ctx, req)
if err != nil {
return nil, err
Expand All @@ -168,11 +171,14 @@ func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*

// GenerateContentStream returns an iterator that enumerates responses.
func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
streamClient, err := m.c.c.StreamGenerateContent(ctx, m.newGenerateContentRequest(newUserContent(parts)))
return &GenerateContentResponseIterator{
sc: streamClient,
err: err,
iter := &GenerateContentResponseIterator{}
req, err := m.newGenerateContentRequest(newUserContent(parts))
if err != nil {
iter.err = err
} else {
iter.sc, iter.err = m.c.c.StreamGenerateContent(ctx, req)
}
return iter
}

func (m *GenerativeModel) generateContent(ctx context.Context, req *pb.GenerateContentRequest) (*GenerateContentResponse, error) {
Expand All @@ -192,16 +198,18 @@ func (m *GenerativeModel) generateContent(ctx context.Context, req *pb.GenerateC
}
}

func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) *pb.GenerateContentRequest {
return &pb.GenerateContentRequest{
Model: m.fullName,
Contents: transformSlice(contents, (*Content).toProto),
SafetySettings: transformSlice(m.SafetySettings, (*SafetySetting).toProto),
Tools: transformSlice(m.Tools, (*Tool).toProto),
ToolConfig: m.ToolConfig.toProto(),
GenerationConfig: m.GenerationConfig.toProto(),
SystemInstruction: m.SystemInstruction.toProto(),
}
func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.GenerateContentRequest, error) {
return pvCatchPanic(func() *pb.GenerateContentRequest {
return &pb.GenerateContentRequest{
Model: m.fullName,
Contents: transformSlice(contents, (*Content).toProto),
SafetySettings: transformSlice(m.SafetySettings, (*SafetySetting).toProto),
Tools: transformSlice(m.Tools, (*Tool).toProto),
ToolConfig: m.ToolConfig.toProto(),
GenerationConfig: m.GenerationConfig.toProto(),
SystemInstruction: m.SystemInstruction.toProto(),
}
})
}

func newUserContent(parts []Part) *Content {
Expand Down Expand Up @@ -244,7 +252,10 @@ func (iter *GenerateContentResponseIterator) Next() (*GenerateContentResponse, e
}

func protoToResponse(resp *pb.GenerateContentResponse) (*GenerateContentResponse, error) {
gcp := (GenerateContentResponse{}).fromProto(resp)
gcp, err := fromProto[GenerateContentResponse](resp)
if err != nil {
return nil, err
}
if gcp == nil {
return nil, errors.New("empty response from model")
}
Expand Down Expand Up @@ -273,20 +284,26 @@ func (iter *GenerateContentResponseIterator) MergedResponse() *GenerateContentRe

// CountTokens counts the number of tokens in the content.
func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) {
req := m.newCountTokensRequest(newUserContent(parts))
req, err := m.newCountTokensRequest(newUserContent(parts))
if err != nil {
return nil, err
}
res, err := m.c.c.CountTokens(ctx, req)
if err != nil {
return nil, err
}

return (CountTokensResponse{}).fromProto(res), nil
return fromProto[CountTokensResponse](res)
}

func (m *GenerativeModel) newCountTokensRequest(contents ...*Content) *pb.CountTokensRequest {
func (m *GenerativeModel) newCountTokensRequest(contents ...*Content) (*pb.CountTokensRequest, error) {
gcr, err := m.newGenerateContentRequest(contents...)
if err != nil {
return nil, err
}
return &pb.CountTokensRequest{
Model: m.fullName,
GenerateContentRequest: m.newGenerateContentRequest(contents...),
}
GenerateContentRequest: gcr,
}, nil
}

// Info returns information about the model.
Expand All @@ -300,7 +317,7 @@ func (c *Client) modelInfo(ctx context.Context, fullName string) (*ModelInfo, er
if err != nil {
return nil, err
}
return (ModelInfo{}).fromProto(res), nil
return fromProto[ModelInfo](res)
}

// A BlockedError indicates that the model's response was blocked.
Expand Down Expand Up @@ -424,3 +441,8 @@ func transformSlice[From, To any](from []From, f func(From) To) []To {
}
return to
}

func fromProto[V interface{ fromProto(P) *V }, P any](p P) (*V, error) {
var v V
return pvCatchPanic(func() *V { return v.fromProto(p) })
}
13 changes: 13 additions & 0 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,3 +766,16 @@ func TestNoAPIKey(t *testing.T) {
t.Fatal("got nil, want error")
}
}

func TestRecoverPanic(t *testing.T) {
// Verify that conversions that used to cause a panic now result in an error.
fr := &FunctionResponse{
Name: "r",
Response: map[string]any{"x": 1 + 2i}, // complex values are invalid
}
var m GenerativeModel
_, err := m.newGenerateContentRequest(newUserContent([]Part{fr}))
if err == nil {
t.Fatal("got nil, want error")
}
}
50 changes: 35 additions & 15 deletions genai/generativelanguagepb_veneer.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (v *BatchEmbedContentsResponse) toProto() *pb.BatchEmbedContentsResponse {
return nil
}
return &pb.BatchEmbedContentsResponse{
Embeddings: transformSlice(v.Embeddings, (*ContentEmbedding).toProto),
Embeddings: pvTransformSlice(v.Embeddings, (*ContentEmbedding).toProto),
}
}

Expand All @@ -50,7 +50,7 @@ func (BatchEmbedContentsResponse) fromProto(p *pb.BatchEmbedContentsResponse) *B
return nil
}
return &BatchEmbedContentsResponse{
Embeddings: transformSlice(p.Embeddings, (ContentEmbedding{}).fromProto),
Embeddings: pvTransformSlice(p.Embeddings, (ContentEmbedding{}).fromProto),
}
}

Expand Down Expand Up @@ -149,7 +149,7 @@ func (v *Candidate) toProto() *pb.Candidate {
Index: pvAddrOrNil(v.Index),
Content: v.Content.toProto(),
FinishReason: pb.Candidate_FinishReason(v.FinishReason),
SafetyRatings: transformSlice(v.SafetyRatings, (*SafetyRating).toProto),
SafetyRatings: pvTransformSlice(v.SafetyRatings, (*SafetyRating).toProto),
CitationMetadata: v.CitationMetadata.toProto(),
TokenCount: v.TokenCount,
}
Expand All @@ -163,7 +163,7 @@ func (Candidate) fromProto(p *pb.Candidate) *Candidate {
Index: pvDerefOrZero(p.Index),
Content: (Content{}).fromProto(p.Content),
FinishReason: FinishReason(p.FinishReason),
SafetyRatings: transformSlice(p.SafetyRatings, (SafetyRating{}).fromProto),
SafetyRatings: pvTransformSlice(p.SafetyRatings, (SafetyRating{}).fromProto),
CitationMetadata: (CitationMetadata{}).fromProto(p.CitationMetadata),
TokenCount: p.TokenCount,
}
Expand All @@ -180,7 +180,7 @@ func (v *CitationMetadata) toProto() *pb.CitationMetadata {
return nil
}
return &pb.CitationMetadata{
CitationSources: transformSlice(v.CitationSources, (*CitationSource).toProto),
CitationSources: pvTransformSlice(v.CitationSources, (*CitationSource).toProto),
}
}

Expand All @@ -189,7 +189,7 @@ func (CitationMetadata) fromProto(p *pb.CitationMetadata) *CitationMetadata {
return nil
}
return &CitationMetadata{
CitationSources: transformSlice(p.CitationSources, (CitationSource{}).fromProto),
CitationSources: pvTransformSlice(p.CitationSources, (CitationSource{}).fromProto),
}
}

Expand Down Expand Up @@ -256,7 +256,7 @@ func (v *Content) toProto() *pb.Content {
return nil
}
return &pb.Content{
Parts: transformSlice(v.Parts, partToProto),
Parts: pvTransformSlice(v.Parts, partToProto),
Role: v.Role,
}
}
Expand All @@ -266,7 +266,7 @@ func (Content) fromProto(p *pb.Content) *Content {
return nil
}
return &Content{
Parts: transformSlice(p.Parts, partFromProto),
Parts: pvTransformSlice(p.Parts, partFromProto),
Role: p.Role,
}
}
Expand Down Expand Up @@ -713,7 +713,7 @@ func (v *GenerateContentResponse) toProto() *pb.GenerateContentResponse {
return nil
}
return &pb.GenerateContentResponse{
Candidates: transformSlice(v.Candidates, (*Candidate).toProto),
Candidates: pvTransformSlice(v.Candidates, (*Candidate).toProto),
PromptFeedback: v.PromptFeedback.toProto(),
UsageMetadata: v.UsageMetadata.toProto(),
}
Expand All @@ -724,7 +724,7 @@ func (GenerateContentResponse) fromProto(p *pb.GenerateContentResponse) *Generat
return nil
}
return &GenerateContentResponse{
Candidates: transformSlice(p.Candidates, (Candidate{}).fromProto),
Candidates: pvTransformSlice(p.Candidates, (Candidate{}).fromProto),
PromptFeedback: (PromptFeedback{}).fromProto(p.PromptFeedback),
UsageMetadata: (UsageMetadata{}).fromProto(p.UsageMetadata),
}
Expand Down Expand Up @@ -1063,7 +1063,7 @@ func (v *PromptFeedback) toProto() *pb.GenerateContentResponse_PromptFeedback {
}
return &pb.GenerateContentResponse_PromptFeedback{
BlockReason: pb.GenerateContentResponse_PromptFeedback_BlockReason(v.BlockReason),
SafetyRatings: transformSlice(v.SafetyRatings, (*SafetyRating).toProto),
SafetyRatings: pvTransformSlice(v.SafetyRatings, (*SafetyRating).toProto),
}
}

Expand All @@ -1073,7 +1073,7 @@ func (PromptFeedback) fromProto(p *pb.GenerateContentResponse_PromptFeedback) *P
}
return &PromptFeedback{
BlockReason: BlockReason(p.BlockReason),
SafetyRatings: transformSlice(p.SafetyRatings, (SafetyRating{}).fromProto),
SafetyRatings: pvTransformSlice(p.SafetyRatings, (SafetyRating{}).fromProto),
}
}

Expand Down Expand Up @@ -1268,7 +1268,7 @@ func (v *Tool) toProto() *pb.Tool {
return nil
}
return &pb.Tool{
FunctionDeclarations: transformSlice(v.FunctionDeclarations, (*FunctionDeclaration).toProto),
FunctionDeclarations: pvTransformSlice(v.FunctionDeclarations, (*FunctionDeclaration).toProto),
}
}

Expand All @@ -1277,7 +1277,7 @@ func (Tool) fromProto(p *pb.Tool) *Tool {
return nil
}
return &Tool{
FunctionDeclarations: transformSlice(p.FunctionDeclarations, (FunctionDeclaration{}).fromProto),
FunctionDeclarations: pvTransformSlice(p.FunctionDeclarations, (FunctionDeclaration{}).fromProto),
}
}

Expand Down Expand Up @@ -1444,7 +1444,7 @@ func pvMapToStructPB(m map[string]any) *structpb.Struct {
}
s, err := structpb.NewStruct(m)
if err != nil {
panic(fmt.Errorf("support.MapToProto: %w", err))
panic(pvPanic(fmt.Errorf("pvMapToStructPB: %w", err)))
}
return s
}
Expand Down Expand Up @@ -1493,3 +1493,23 @@ func pvDurationFromProto(d *durationpb.Duration) time.Duration {
}
return d.AsDuration()
}

// pvPanic wraps panics from support functions.
// User-provided functions in the same package can also use it.
// It allows callers to distinguish conversion function panics from other panics.
type pvPanic error

// pvCatchPanic recovers from panics of type pvPanic and
// returns an error instead.
func pvCatchPanic[T any](f func() T) (_ T, err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(pvPanic); ok {
err = r.(error)
} else {
panic(r)
}
}
}()
return f(), nil
}
Loading

0 comments on commit 5220ad2

Please sign in to comment.