diff --git a/sdk/cognitiveservices/azopenai/client_test.go b/sdk/cognitiveservices/azopenai/client_test.go new file mode 100644 index 000000000000..c1dbdce0387e --- /dev/null +++ b/sdk/cognitiveservices/azopenai/client_test.go @@ -0,0 +1,222 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "context" + "log" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +var ( + endpoint = os.Getenv("AOAI_ENDPOINT") + apiKey = os.Getenv("AOAI_API_KEY") +) + +func TestClient_GetChatCompletions(t *testing.T) { + type args struct { + ctx context.Context + deploymentID string + body ChatCompletionsOptions + options *ClientGetChatCompletionsOptions + } + cred := KeyCredential{APIKey: apiKey} + deploymentID := "gpt-35-turbo" + chatClient, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) + if err != nil { + log.Fatalf("%v", err) + } + tests := []struct { + name string + client *Client + args args + want ClientGetChatCompletionsResponse + wantErr bool + }{ + { + name: "ChatCompletions", + client: chatClient, + args: args{ + ctx: context.TODO(), + deploymentID: "gpt-35-turbo", + body: ChatCompletionsOptions{ + Messages: []*ChatMessage{ + { + Role: to.Ptr(ChatRole("user")), + Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."), + }, + }, + MaxTokens: to.Ptr(int32(1024)), + Temperature: to.Ptr(float32(0.0)), + }, + options: nil, + }, + want: ClientGetChatCompletionsResponse{ + ChatCompletions: ChatCompletions{ + Choices: []*ChatChoice{ + { + Message: &ChatChoiceMessage{ + Role: to.Ptr(ChatRole("assistant")), + Content: to.Ptr("1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100."), + }, + Index: to.Ptr(int32(0)), + FinishReason: to.Ptr(CompletionsFinishReason("stop")), + }, + }, + Usage: &CompletionsUsage{ + CompletionTokens: to.Ptr(int32(299)), + PromptTokens: to.Ptr(int32(37)), + TotalTokens: to.Ptr(int32(336)), + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetChatCompletions(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetChatCompletions() error = %v, wantErr %v", err, tt.wantErr) + return + } + opts := cmpopts.IgnoreFields(ChatCompletions{}, "Created", "ID") + if diff := cmp.Diff(tt.want.ChatCompletions, got.ChatCompletions, opts); diff != "" { + t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) + } + }) + } +} + +func TestClient_GetCompletions(t *testing.T) { + type args struct { + ctx context.Context + deploymentID string + body CompletionsOptions + options *ClientGetCompletionsOptions + } + cred := KeyCredential{APIKey: apiKey} + deploymentID := "text-davinci-003" + client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) + if err != nil { + log.Fatalf("%v", err) + } + tests := []struct { + name string + client *Client + args args + want ClientGetCompletionsResponse + wantErr bool + }{ + { + name: "chatbot", + client: client, + args: args{ + ctx: context.TODO(), + deploymentID: deploymentID, + body: CompletionsOptions{ + Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + }, + options: nil, + }, + want: ClientGetCompletionsResponse{ + Completions: Completions{ + Choices: []*Choice{ + { + Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."), + Index: to.Ptr(int32(0)), + FinishReason: to.Ptr(CompletionsFinishReason("stop")), + Logprobs: nil, + }, + }, + Usage: &CompletionsUsage{ + CompletionTokens: to.Ptr(int32(85)), + PromptTokens: to.Ptr(int32(6)), + TotalTokens: to.Ptr(int32(91)), + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetCompletions(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetCompletions() error = %v, wantErr %v", err, tt.wantErr) + return + } + opts := cmpopts.IgnoreFields(Completions{}, "Created", "ID") + if diff := cmp.Diff(tt.want.Completions, got.Completions, opts); diff != "" { + t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) + } + }) + } +} + +func TestClient_GetEmbeddings(t *testing.T) { + type args struct { + ctx context.Context + deploymentID string + body EmbeddingsOptions + options *ClientGetEmbeddingsOptions + } + deploymentID := "embedding" + cred := KeyCredential{APIKey: apiKey} + client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) + if err != nil { + log.Fatalf("%v", err) + } + tests := []struct { + name string + client *Client + args args + want ClientGetEmbeddingsResponse + wantErr bool + }{ + { + name: "Embeddings", + client: client, + args: args{ + ctx: context.TODO(), + deploymentID: "embedding", + body: EmbeddingsOptions{ + Input: "Your text string goes here", + Model: to.Ptr("text-similarity-curie-001"), + }, + options: nil, + }, + want: ClientGetEmbeddingsResponse{ + Embeddings{ + Data: []*EmbeddingItem{}, + Usage: &EmbeddingsUsage{}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr) + return + } + if len(got.Embeddings.Data[0].Embedding) != 4096 { + t.Errorf("Client.GetEmbeddings() len(Data) want 4096, got %d", len(got.Embeddings.Data)) + return + } + }) + } +} diff --git a/sdk/cognitiveservices/azopenai/custom_client_test.go b/sdk/cognitiveservices/azopenai/custom_client_test.go new file mode 100644 index 000000000000..ba0496799e95 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/custom_client_test.go @@ -0,0 +1,125 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "context" + "io" + "reflect" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" +) + +func TestNewClient(t *testing.T) { + type args struct { + endpoint string + credential azcore.TokenCredential + deploymentID string + options *ClientOptions + } + tests := []struct { + name string + args args + want *Client + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewClient(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewClient() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewClientWithKeyCredential(t *testing.T) { + type args struct { + endpoint string + credential KeyCredential + deploymentID string + options *ClientOptions + } + tests := []struct { + name string + args args + want *Client + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewClientWithKeyCredential(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("NewClientWithKeyCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewClientWithKeyCredential() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_GetCompletionsStream(t *testing.T) { + body := CompletionsOptions{ + Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, + MaxTokens: to.Ptr(int32(2048 - 127)), + Temperature: to.Ptr(float32(0.0)), + } + cred := KeyCredential{APIKey: apiKey} + deploymentID := "text-davinci-003" + client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) + if err != nil { + t.Errorf("NewClientWithKeyCredential() error = %v", err) + return + } + response, err := client.GetCompletionsStream(context.TODO(), body, nil) + if err != nil { + t.Errorf("Client.GetCompletionsStream() error = %v", err) + return + } + reader := response.Events + defer reader.Close() + + var sb strings.Builder + var eventCount int + for { + event, err := reader.Read() + if err == io.EOF { + break + } + eventCount++ + if err != nil { + t.Errorf("reader.Read() error = %v", err) + return + } + sb.WriteString(*event.Choices[0].Text) + } + got := sb.String() + const want = "\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models." + if got != want { + i := 0 + for i < len(got) && i < len(want) && got[i] == want[i] { + i++ + } + t.Errorf("Client.GetCompletionsStream() text[%d] = %c, want %c", i, got[i], want[i]) + } + if eventCount != 86 { + t.Errorf("Client.GetCompletionsStream() got = %v, want %v", eventCount, 1) + } +} diff --git a/sdk/cognitiveservices/azopenai/go.mod b/sdk/cognitiveservices/azopenai/go.mod index e051c62950e0..6f3ac167661e 100644 --- a/sdk/cognitiveservices/azopenai/go.mod +++ b/sdk/cognitiveservices/azopenai/go.mod @@ -2,10 +2,13 @@ module github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai go 1.18 -require github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 + github.com/google/go-cmp v0.5.9 +) require ( - github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect - golang.org/x/net v0.7.0 // indirect - golang.org/x/text v0.7.0 // indirect + golang.org/x/net v0.8.0 // indirect + golang.org/x/text v0.8.0 // indirect ) diff --git a/sdk/cognitiveservices/azopenai/go.sum b/sdk/cognitiveservices/azopenai/go.sum index d39a720dafd8..fb24a34f3b79 100644 --- a/sdk/cognitiveservices/azopenai/go.sum +++ b/sdk/cognitiveservices/azopenai/go.sum @@ -1,12 +1,14 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 h1:rTnT/Jrcm+figWlYz4Ixzt0SJVR2cMC8lvZcimipiEY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= diff --git a/sdk/cognitiveservices/azopenai/policy_apikey_test.go b/sdk/cognitiveservices/azopenai/policy_apikey_test.go new file mode 100644 index 000000000000..d688530dd896 --- /dev/null +++ b/sdk/cognitiveservices/azopenai/policy_apikey_test.go @@ -0,0 +1,80 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package azopenai + +import ( + "context" + "net/http" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +func TestNewAPIKeyPolicy(t *testing.T) { + type args struct { + header string + cred KeyCredential + } + simpleCred := KeyCredential{APIKey: "apiKey"} + simpleHeader := "headerName" + tests := []struct { + name string + args args + want *APIKeyPolicy + }{ + { + name: "simple", + args: args{ + cred: simpleCred, + header: simpleHeader, + }, + want: &APIKeyPolicy{ + header: simpleHeader, + cred: simpleCred, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewAPIKeyPolicy(tt.args.cred, tt.args.header); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewAPIKeyPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAPIKeyPolicy_Success(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + cred := KeyCredential{ + APIKey: "secret", + } + authPolicy := NewAPIKeyPolicy(cred, "api-key") + pipeline := runtime.NewPipeline( + "testmodule", + "v0.1.0", + runtime.PipelineOptions{PerRetry: []policy.Policy{authPolicy}}, + &policy.ClientOptions{ + Transport: srv, + }) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Expected nil error but received one") + } + if hdrValue := resp.Request.Header.Get("api-key"); hdrValue != "secret" { + t.Fatalf("expected api-key '%s', got '%s'", "secret", hdrValue) + } +}