diff --git a/app/server/model/client.go b/app/server/model/client.go index ed145575..04819faf 100644 --- a/app/server/model/client.go +++ b/app/server/model/client.go @@ -2,7 +2,9 @@ package model import ( "context" + "fmt" "log" + "regexp" "strings" "time" @@ -68,6 +70,15 @@ func createChatCompletionStream( // for retriable errors, retry with exponential backoff if numRetry < 5 { + // check if the error message contains a retry duration + if duration := parseRetryAfter(err.Error()); duration != nil { + log.Printf("Retry duration found: %v\n", *duration) + // wait for the duration times 1.5 to give some buffer + waitDuration := time.Duration(float64(*duration) * 1.5) + time.Sleep(waitDuration) + return createChatCompletionStream(client, ctx, req, numRetry+1) + } + waitBackoff(numRetry) return createChatCompletionStream(client, ctx, req, numRetry+1) } @@ -109,6 +120,15 @@ func createChatCompletion( // for retriable errors, retry with exponential backoff if numRetry < 5 { + // check if the error message contains a retry duration + if duration := parseRetryAfter(err.Error()); duration != nil { + log.Printf("Retry duration found: %v\n", *duration) + // wait for the duration times 1.5 to give some buffer + waitDuration := time.Duration(float64(*duration) * 1.5) + time.Sleep(waitDuration) + return createChatCompletion(client, ctx, req, numRetry+1) + } + waitBackoff(numRetry) return createChatCompletion(client, ctx, req, numRetry+1) } @@ -153,3 +173,20 @@ func waitBackoff(numRetry int) { log.Printf("Retrying in %v\n", d) time.Sleep(d) } + +// parseRetryAfter takes an error message and returns the retry duration or nil if no duration is found. 
// parseRetryAfter looks for a "try again in <duration>" hint inside
// errorMessage and returns the hinted duration.
//
// Only hints expressed as a single value in seconds or milliseconds are
// recognized (for example "20s", "1.5s" or "250ms"). It returns nil when the
// message carries no such hint, or when the matched text cannot be parsed as a
// time.Duration (for example because the value overflows).
func parseRetryAfter(errorMessage string) *time.Duration {
	match := retryAfterPattern.FindStringSubmatch(errorMessage)
	if match == nil {
		return nil
	}

	// match[1] is the duration including its unit, which is already in the
	// syntax time.ParseDuration accepts.
	duration, err := time.ParseDuration(match[1])
	if err != nil {
		fmt.Println("Error parsing duration:", err)
		return nil
	}
	return &duration
}

// retryAfterPattern captures the duration (number plus "ms" or "s" unit) that
// follows the phrase "try again in". It is compiled once at package scope so
// that parseRetryAfter does not recompile it on every call.
var retryAfterPattern = regexp.MustCompile(`try again in (\d+(\.\d+)?(ms|s))`)