From 5508952c591e9ca2700d8c7f0567203e660a29e4 Mon Sep 17 00:00:00 2001 From: Mike Kistler Date: Sun, 23 Apr 2023 16:57:06 -0500 Subject: [PATCH] Add GetCompletionsSSE and supporting methods --- .../azopenai/custom_client.go | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/sdk/cognitiveservices/azopenai/custom_client.go b/sdk/cognitiveservices/azopenai/custom_client.go index b2b17733b9d3..1f07869cc896 100644 --- a/sdk/cognitiveservices/azopenai/custom_client.go +++ b/sdk/cognitiveservices/azopenai/custom_client.go @@ -9,9 +9,18 @@ package azopenai // this file contains handwritten additions to the generated code import ( + "bufio" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" ) const ( @@ -58,3 +67,69 @@ func NewClientWithKeyCredential(endpoint string, credential KeyCredential, optio } return &Client{endpoint: endpoint + "/openai", internal: azcoreClient}, nil } + +// Support for SSE + +// Generated from API version 2022-12-01 +// - options - ClientGetCompletionsOptions contains the optional parameters for the Client.GetCompletions method. +func (client *Client) GetCompletionsSSE(ctx context.Context, deploymentID string, body CompletionRequest, options *ClientGetCompletionsOptions) (*http.Response, error) { + body.Stream = to.Ptr(true) + req, err := client.getCompletionsCreateRequest(ctx, deploymentID, body, options) + if err != nil { + return nil, err + } + resp, err := client.internal.Pipeline().Do(req) + if err != nil { + return nil, err + } + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + return resp, nil +} + +type EventReader[T any] struct { + reader io.Reader // Required for Closing + scanner *bufio.Scanner +} + +func NewEventReader[T any](r io.Reader) *EventReader[T] { + return &EventReader[T]{reader: r, scanner: bufio.NewScanner(r)} +} + +func (er *EventReader[T]) Read() (T, error) { + // https://html.spec.whatwg.org/multipage/server-sent-events.html + for er.scanner.Scan() { // Scan while no error + line := er.scanner.Text() // Get the line & interpret the event stream: + + if line == "" || line[0] == ':' { // If the line is blank or is a comment, skip it + continue + } + + if strings.Contains(line, ":") { // If the line contains a U+003A COLON character (:), process the field + tokens := strings.SplitN(line, ":", 2) + tokens[0], tokens[1] = strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1]) + var data T + switch tokens[0] { + case "data": // return the deserialized JSON object + if tokens[1] == "[DONE]" { // If data is [DONE], end of stream was reached + return data, io.EOF + } + //fmt.Println(tokens[1]) + err := json.Unmarshal([]byte(tokens[1]), &data) + return data, err + + default: // Any other event type is an unexpected + return data, errors.New("Unexpected event type: " + tokens[0]) + } + // Unreachable + } + } + return *new(T), er.scanner.Err() +} + +func (er *EventReader[T]) Close() { + if closer, ok := er.reader.(io.Closer); ok { + closer.Close() + } +}