Skip to content

Commit

Permalink
Added the ability to provision AOAI as an optional component (#46570)
Browse files Browse the repository at this point in the history
* fixed cminfra

* openai provisioning works

* open ai client added

* progress

* changed ai to key auth

* moved CM to WorkspaceClient abstraction

* refactored built-in methods and extension methods

* openai works

* updated api file

* disabled live tests

* updated version

* small tweaks

* updated api file

* PR feedback

* removed stj override
  • Loading branch information
KrzysztofCwalina authored Oct 14, 2024
1 parent 8a252ab commit fabfa6c
Show file tree
Hide file tree
Showing 22 changed files with 679 additions and 290 deletions.
2 changes: 2 additions & 0 deletions sdk/provisioning/Azure.Provisioning.CloudMachine/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.0.0-beta.1 (Unreleased)

## 1.0.0-beta.2 (Unreleased)

### Features Added

### Breaking Changes
Expand Down
Original file line number Diff line number Diff line change
@@ -1,52 +1,83 @@
namespace Azure.CloudMachine
{
public partial class ClientCache
public partial class CloudMachineClient : Azure.CloudMachine.CloudMachineWorkspace
{
public ClientCache() { }
public T Get<T>(string id, System.Func<T> value) where T : class { throw null; }
public CloudMachineClient(Azure.Core.TokenCredential? credential = null, Microsoft.Extensions.Configuration.IConfiguration? configuration = null) : base (default(Azure.Core.TokenCredential), default(Microsoft.Extensions.Configuration.IConfiguration)) { }
public Azure.CloudMachine.MessagingServices Messaging { get { throw null; } }
public Azure.CloudMachine.StorageServices Storage { get { throw null; } }
}
public partial class CloudMachineClient
public partial class CloudMachineWorkspace : Azure.Core.WorkspaceClient
{
protected CloudMachineClient() { }
public CloudMachineClient(Azure.Identity.DefaultAzureCredential? credential = null, Microsoft.Extensions.Configuration.IConfiguration? configuration = null) { }
public CloudMachineWorkspace(Azure.Core.TokenCredential? credential = null, Microsoft.Extensions.Configuration.IConfiguration? configuration = null) { }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public Azure.CloudMachine.ClientCache ClientCache { get { throw null; } }
public Azure.Core.TokenCredential Credential { get { throw null; } }
public string Id { get { throw null; } }
public override Azure.Core.TokenCredential Credential { get { throw null; } }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public Azure.CloudMachine.CloudMachineClient.CloudMachineProperties Properties { get { throw null; } }
public string Id { get { throw null; } }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public override bool Equals(object? obj) { throw null; }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public override Azure.Core.ClientConfiguration? GetConfiguration(string clientId, string? instanceId = null) { throw null; }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public override int GetHashCode() { throw null; }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public override string ToString() { throw null; }
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
public partial struct CloudMachineProperties
{
private object _dummy;
private int _dummyPrimitive;
public System.Uri BlobServiceUri { get { throw null; } }
public System.Uri DefaultContainerUri { get { throw null; } }
public System.Uri KeyVaultUri { get { throw null; } }
public string ServiceBusNamespace { get { throw null; } }
}
}
public static partial class MessagingServices
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
public readonly partial struct MessagingServices
{
public static void Send(this Azure.CloudMachine.CloudMachineClient cm, object serializable) { }
private readonly object _dummy;
private readonly int _dummyPrimitive;
public void SendMessage(object serializable) { }
public void WhenMessageReceived(System.Action<string> received) { }
}
public static partial class StorageServices
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
public readonly partial struct StorageServices
{
public static System.BinaryData Download(this Azure.CloudMachine.CloudMachineClient cm, string name) { throw null; }
public static string Upload(this Azure.CloudMachine.CloudMachineClient cm, object json, string? name = null) { throw null; }
private readonly object _dummy;
private readonly int _dummyPrimitive;
public System.BinaryData DownloadBlob(string name) { throw null; }
public string UploadBlob(object json, string? name = null) { throw null; }
public void WhenBlobCreated(System.Func<string, System.Threading.Tasks.Task> function) { }
public void WhenBlobUploaded(System.Action<string> function) { }
}
}
namespace Azure.Core
{
public partial class ClientCache
{
public ClientCache() { }
public T Get<T>(string id, System.Func<T> value) where T : class { throw null; }
}
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
public readonly partial struct ClientConfiguration
{
private readonly object _dummy;
private readonly int _dummyPrimitive;
public ClientConfiguration(string endpoint, string? apiKey = null) { throw null; }
public string? ApiKey { get { throw null; } }
public Azure.Core.CredentialType CredentialType { get { throw null; } }
public string Endpoint { get { throw null; } }
}
public enum CredentialType
{
EntraId = 0,
ApiKey = 1,
}
public abstract partial class WorkspaceClient
{
protected WorkspaceClient() { }
public abstract Azure.Core.TokenCredential Credential { get; }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public Azure.Core.ClientCache Subclients { get { throw null; } }
public abstract Azure.Core.ClientConfiguration? GetConfiguration(string clientId, string? instanceId = null);
}
}
namespace Azure.Provisioning.CloudMachine
{
public abstract partial class CloudMachineFeature
{
protected CloudMachineFeature() { }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public abstract void AddTo(Azure.Provisioning.CloudMachine.CloudMachineInfrastructure cm);
}
public partial class CloudMachineInfrastructure
Expand All @@ -65,24 +96,26 @@ namespace Azure.Provisioning.CloudMachine.KeyVault
{
public static partial class KeyVaultExtensions
{
public static Azure.Security.KeyVault.Secrets.SecretClient GetKeyVaultSecretClient(this Azure.CloudMachine.CloudMachineClient client) { throw null; }
public static Azure.Security.KeyVault.Secrets.SecretClient GetKeyVaultSecretsClient(this Azure.Core.WorkspaceClient workspace) { throw null; }
}
public partial class KeyVaultFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature
{
public KeyVaultFeature() { }
public KeyVaultFeature(Azure.Provisioning.KeyVault.KeyVaultSku? sku = null) { }
public Azure.Provisioning.KeyVault.KeyVaultSku Sku { get { throw null; } set { } }
public override void AddTo(Azure.Provisioning.CloudMachine.CloudMachineInfrastructure cm) { }
public override void AddTo(Azure.Provisioning.CloudMachine.CloudMachineInfrastructure infrastructure) { }
}
}
namespace Azure.Provisioning.CloudMachine.OpenAI
{
public partial class OpenAIFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature
{
public OpenAIFeature() { }
public override void AddTo(Azure.Provisioning.CloudMachine.CloudMachineInfrastructure cm) { }
public OpenAIFeature(string model, string modelVersion) { }
public string Model { get { throw null; } }
public string ModelVersion { get { throw null; } }
public override void AddTo(Azure.Provisioning.CloudMachine.CloudMachineInfrastructure cloudMachine) { }
}
public static partial class OpenAIFeatureExtensions
{
public static Azure.Security.KeyVault.Secrets.SecretClient GetOpenAIClient(this Azure.CloudMachine.CloudMachineClient client) { throw null; }
public static OpenAI.Chat.ChatClient GetOpenAIChatClient(this Azure.Core.WorkspaceClient workspace) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<Description>Azure.Provisioning.CloudMachine simplifies declarative resource provisioning in .NET.</Description>
<Version>1.0.0-beta.1</Version>
<Version>1.0.0-beta.2</Version>
<TargetFrameworks>$(RequiredTargetFrameworks)</TargetFrameworks>
<LangVersion>12</LangVersion>

Expand All @@ -11,6 +11,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Azure.Messaging.ServiceBus" />
<PackageReference Include="Azure.Security.KeyVault.Secrets" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using Azure.Core;
using Azure.Provisioning.Authorization;
using Azure.Provisioning.Expressions;
using Azure.Provisioning.KeyVault;
using Azure.Security.KeyVault.Secrets;

namespace Azure.Provisioning.CloudMachine.KeyVault;

public class KeyVaultFeature : CloudMachineFeature
{
public KeyVaultSku Sku { get; set; }

public KeyVaultFeature(KeyVaultSku? sku = default)
{
if (sku == null)
{
sku = new KeyVaultSku { Name = KeyVaultSkuName.Standard, Family = KeyVaultSkuFamily.A, };
}
Sku = sku;
}
public override void AddTo(CloudMachineInfrastructure infrastructure)
{
// Add a KeyVault to the CloudMachine infrastructure.
KeyVaultService keyVaultResource = new("cm_kv")
{
Name = infrastructure.Id,
Properties =
new KeyVaultProperties
{
Sku = this.Sku,
TenantId = BicepFunction.GetSubscription().TenantId,
EnabledForDeployment = true,
AccessPolicies = [
new KeyVaultAccessPolicy() {
ObjectId = infrastructure.PrincipalIdParameter,
Permissions = new IdentityAccessPermissions() {
Secrets = [IdentityAccessSecretPermission.Get, IdentityAccessSecretPermission.Set]
},
TenantId = infrastructure.Identity.TenantId
}
]
},
};

infrastructure.AddResource(keyVaultResource);

RoleAssignment ra = keyVaultResource.CreateRoleAssignment(KeyVaultBuiltInRole.KeyVaultAdministrator, RoleManagementPrincipalType.User, infrastructure.PrincipalIdParameter);
infrastructure.AddResource(ra);

// necessary until ResourceName is settable via AssignRole.
RoleAssignment kvMiRoleAssignment = new RoleAssignment(keyVaultResource.IdentifierName + "_" + infrastructure.Identity.IdentifierName + "_" + KeyVaultBuiltInRole.GetBuiltInRoleName(KeyVaultBuiltInRole.KeyVaultAdministrator));
kvMiRoleAssignment.Name = BicepFunction.CreateGuid(keyVaultResource.Id, infrastructure.Identity.Id, BicepFunction.GetSubscriptionResourceId("Microsoft.Authorization/roleDefinitions", KeyVaultBuiltInRole.KeyVaultAdministrator.ToString()));
kvMiRoleAssignment.Scope = new IdentifierExpression(keyVaultResource.IdentifierName);
kvMiRoleAssignment.PrincipalType = RoleManagementPrincipalType.ServicePrincipal;
kvMiRoleAssignment.RoleDefinitionId = BicepFunction.GetSubscriptionResourceId("Microsoft.Authorization/roleDefinitions", KeyVaultBuiltInRole.KeyVaultAdministrator.ToString());
kvMiRoleAssignment.PrincipalId = infrastructure.Identity.PrincipalId;
infrastructure.AddResource(kvMiRoleAssignment);
}
}

public static class KeyVaultExtensions
{
public static SecretClient GetKeyVaultSecretsClient(this WorkspaceClient workspace)
{
ClientConfiguration? connectionMaybe = workspace.GetConfiguration(typeof(SecretClient).FullName);
if (connectionMaybe == null)
{
throw new Exception("Connection not found");
}

ClientConfiguration connection = connectionMaybe.Value;
if (connection.CredentialType == CredentialType.EntraId)
{
return new(new Uri(connection.Endpoint), workspace.Credential);
}
throw new Exception("ApiKey not supported");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Diagnostics.Contracts;
using Azure.AI.OpenAI;
using Azure.CloudMachine;
using Azure.Core;
using Azure.Provisioning.Authorization;
using Azure.Provisioning.CognitiveServices;
using OpenAI.Chat;

namespace Azure.Provisioning.CloudMachine.OpenAI;

public class OpenAIFeature : CloudMachineFeature
{
public string Model { get; }
public string ModelVersion { get; }

public OpenAIFeature(string model, string modelVersion) { Model = model; ModelVersion = modelVersion; }

public override void AddTo(CloudMachineInfrastructure cloudMachine)
{
CognitiveServicesAccount cognitiveServices = new("openai")
{
Name = cloudMachine.Id,
Kind = "OpenAI",
Sku = new CognitiveServicesSku { Name = "S0" },
Properties = new CognitiveServicesAccountProperties()
{
PublicNetworkAccess = ServiceAccountPublicNetworkAccess.Enabled,
CustomSubDomainName = cloudMachine.Id
},
};

cloudMachine.AddResource(cognitiveServices.CreateRoleAssignment(
CognitiveServicesBuiltInRole.CognitiveServicesOpenAIContributor,
RoleManagementPrincipalType.User,
cloudMachine.PrincipalIdParameter)
);

// TODO: if we every support more than one deployment, they need to be chained using DependsOn.
// The reason is that deployments need to be deployed/created serially.
CognitiveServicesAccountDeployment deployment = new("openai_deployment", "2023-05-01")
{
Parent = cognitiveServices,
Name = cloudMachine.Id,
Properties = new CognitiveServicesAccountDeploymentProperties()
{
Model = new CognitiveServicesAccountDeploymentModel() {
Name = this.Model,
Format = "OpenAI",
Version = this.ModelVersion
}
},
};

cloudMachine.AddResource(cognitiveServices);
cloudMachine.AddResource(deployment);
}
}

public static class OpenAIFeatureExtensions
{
public static ChatClient GetOpenAIChatClient(this WorkspaceClient workspace)
{
string chatClientId = typeof(ChatClient).FullName;

ChatClient client = workspace.Subclients.Get(chatClientId, () =>
{
string azureOpenAIClientId = typeof(AzureOpenAIClient).FullName;
AzureOpenAIClient aoia = workspace.Subclients.Get(azureOpenAIClientId, () =>
{
ClientConfiguration? connectionMaybe = workspace.GetConfiguration(typeof(AzureOpenAIClient).FullName);
if (connectionMaybe == null) throw new Exception("Connection not found");
ClientConfiguration connection = connectionMaybe.Value;
Uri endpoint = new(connection.Endpoint);
var clientOptions = new AzureOpenAIClientOptions();
if (connection.CredentialType == CredentialType.EntraId)
{
AzureOpenAIClient aoai = new(endpoint, workspace.Credential, clientOptions);
return aoai;
}
else
{
AzureOpenAIClient aoai = new(endpoint, new ApiKeyCredential(connection.ApiKey!), clientOptions);
return aoai;
}
});
string azureOpenAIChatClientId = typeof(ChatClient).FullName;
ClientConfiguration? connectionMaybe = workspace.GetConfiguration(azureOpenAIChatClientId);
if (connectionMaybe == null) throw new Exception("Connection not found");
var connection = connectionMaybe.Value;
ChatClient chat = aoia.GetChatClient(connection.Endpoint);
return chat;
});

return client;
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.ComponentModel;

namespace Azure.Provisioning.CloudMachine;

public abstract class CloudMachineFeature
{
[EditorBrowsable(EditorBrowsableState.Never)]
public abstract void AddTo(CloudMachineInfrastructure cm);
}
Loading

0 comments on commit fabfa6c

Please sign in to comment.