Skip to content

Commit

Permalink
Merge pull request #3 from johnnypham/akv-upgrade
Browse files Browse the repository at this point in the history
Enclave example code and updating some public API comments
  • Loading branch information
cheenamalhotra authored Feb 5, 2021
2 parents aea83c6 + 0c6c2f5 commit 52dc29d
Show file tree
Hide file tree
Showing 6 changed files with 329 additions and 19 deletions.
253 changes: 253 additions & 0 deletions doc/samples/AzureKeyVaultProviderWithEnclaveProviderExample_2_0.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
using System;
//<Snippet1>
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Threading.Tasks;
using Azure.Identity;
using Microsoft.Data.SqlClient;
using Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider;

namespace AKVEnclaveExample
{
class Program
{
static readonly string s_algorithm = "RSA_OAEP";

// ********* Provide details here ***********
static readonly string s_akvUrl = "https://{KeyVaultName}.vault.azure.net/keys/{Key}/{KeyIdentifier}";
static readonly string s_clientId = "{Application_Client_ID}";
static readonly string s_clientSecret = "{Application_Client_Secret}";
static readonly string s_tenantId = "{Azure_Key_Vault_Active_Directory_Tenant_Id}";
static readonly string s_connectionString = "Server={Server}; Database={database}; Integrated Security=true; Column Encryption Setting=Enabled; Attestation Protocol=HGS; Enclave Attestation Url = {attestation_url_for_HGS};";
// ******************************************

static void Main(string[] args)
{
// Initialize AKV provider
ClientSecretCredential clientSecretCredential = new ClientSecretCredential(s_tenantId, s_clientId, s_clientSecret);
SqlColumnEncryptionAzureKeyVaultProvider akvProvider = new SqlColumnEncryptionAzureKeyVaultProvider(clientSecretCredential);

// Register AKV provider
SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders: new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>(capacity: 1, comparer: StringComparer.OrdinalIgnoreCase)
{
{ SqlColumnEncryptionAzureKeyVaultProvider.ProviderName, akvProvider}
});
Console.WriteLine("AKV provider Registered");

// Create connection to database
using (SqlConnection sqlConnection = new SqlConnection(s_connectionString))
{
string cmkName = "CMK_WITH_AKV";
string cekName = "CEK_WITH_AKV";
string tblName = "AKV_TEST_TABLE";

CustomerRecord customer = new CustomerRecord(1, @"Microsoft", @"Corporation");

try
{
sqlConnection.Open();

// Drop Objects if exists
dropObjects(sqlConnection, cmkName, cekName, tblName);

// Create Column Master Key with AKV Url
createCMK(sqlConnection, cmkName, akvProvider);
Console.WriteLine("Column Master Key created.");

// Create Column Encryption Key
createCEK(sqlConnection, cmkName, cekName, akvProvider);
Console.WriteLine("Column Encryption Key created.");

// Create Table with Encrypted Columns
createTbl(sqlConnection, cekName, tblName);
Console.WriteLine("Table created with Encrypted columns.");

// Insert Customer Record in table
insertData(sqlConnection, tblName, customer);
Console.WriteLine("Encryted data inserted.");

// Read data from table
verifyData(sqlConnection, tblName, customer);
Console.WriteLine("Data validated successfully.");
}
finally
{
// Drop table and keys
dropObjects(sqlConnection, cmkName, cekName, tblName);
Console.WriteLine("Dropped Table, CEK and CMK");
}

Console.WriteLine("Completed AKV provider Sample.");

Console.ReadKey();
}
}

private static void createCMK(SqlConnection sqlConnection, string cmkName, SqlColumnEncryptionAzureKeyVaultProvider sqlColumnEncryptionAzureKeyVaultProvider)
{
string KeyStoreProviderName = SqlColumnEncryptionAzureKeyVaultProvider.ProviderName;

byte[] cmkSign = sqlColumnEncryptionAzureKeyVaultProvider.SignColumnMasterKeyMetadata(s_akvUrl, true);
string cmkSignStr = string.Concat("0x", BitConverter.ToString(cmkSign).Replace("-", string.Empty));

string sql =
$@"CREATE COLUMN MASTER KEY [{cmkName}]
WITH (
KEY_STORE_PROVIDER_NAME = N'{KeyStoreProviderName}',
KEY_PATH = N'{s_akvUrl}',
ENCLAVE_COMPUTATIONS (SIGNATURE = {cmkSignStr})
);";

using (SqlCommand command = sqlConnection.CreateCommand())
{
command.CommandText = sql;
command.ExecuteNonQuery();
}
}

private static void createCEK(SqlConnection sqlConnection, string cmkName, string cekName, SqlColumnEncryptionAzureKeyVaultProvider sqlColumnEncryptionAzureKeyVaultProvider)
{
string sql =
$@"CREATE COLUMN ENCRYPTION KEY [{cekName}]
WITH VALUES (
COLUMN_MASTER_KEY = [{cmkName}],
ALGORITHM = '{s_algorithm}',
ENCRYPTED_VALUE = {GetEncryptedValue(sqlColumnEncryptionAzureKeyVaultProvider)}
)";

using (SqlCommand command = sqlConnection.CreateCommand())
{
command.CommandText = sql;
command.ExecuteNonQuery();
}
}

private static string GetEncryptedValue(SqlColumnEncryptionAzureKeyVaultProvider sqlColumnEncryptionAzureKeyVaultProvider)
{
byte[] plainTextColumnEncryptionKey = new byte[32];
RNGCryptoServiceProvider rngCsp = new RNGCryptoServiceProvider();
rngCsp.GetBytes(plainTextColumnEncryptionKey);

byte[] encryptedColumnEncryptionKey = sqlColumnEncryptionAzureKeyVaultProvider.EncryptColumnEncryptionKey(s_akvUrl, s_algorithm, plainTextColumnEncryptionKey);
string EncryptedValue = string.Concat("0x", BitConverter.ToString(encryptedColumnEncryptionKey).Replace("-", string.Empty));
return EncryptedValue;
}

private static void createTbl(SqlConnection sqlConnection, string cekName, string tblName)
{
string ColumnEncryptionAlgorithmName = @"AEAD_AES_256_CBC_HMAC_SHA_256";

string sql =
$@"CREATE TABLE [dbo].[{tblName}]
(
[CustomerId] [int] ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{cekName}], ENCRYPTION_TYPE = RANDOMIZED, ALGORITHM = '{ColumnEncryptionAlgorithmName}'),
[FirstName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{cekName}], ENCRYPTION_TYPE = RANDOMIZED, ALGORITHM = '{ColumnEncryptionAlgorithmName}'),
[LastName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{cekName}], ENCRYPTION_TYPE = RANDOMIZED, ALGORITHM = '{ColumnEncryptionAlgorithmName}')
)";

using (SqlCommand command = sqlConnection.CreateCommand())
{
command.CommandText = sql;
command.ExecuteNonQuery();
}
}

private static void insertData(SqlConnection sqlConnection, string tblName, CustomerRecord customer)
{
string insertSql = $"INSERT INTO [{tblName}] (CustomerId, FirstName, LastName) VALUES (@CustomerId, @FirstName, @LastName);";

using (SqlTransaction sqlTransaction = sqlConnection.BeginTransaction())
using (SqlCommand sqlCommand = new SqlCommand(insertSql,
connection: sqlConnection, transaction: sqlTransaction,
columnEncryptionSetting: SqlCommandColumnEncryptionSetting.Enabled))
{
sqlCommand.Parameters.AddWithValue(@"CustomerId", customer.Id);
sqlCommand.Parameters.AddWithValue(@"FirstName", customer.FirstName);
sqlCommand.Parameters.AddWithValue(@"LastName", customer.LastName);

sqlCommand.ExecuteNonQuery();
sqlTransaction.Commit();
}
}

private static void verifyData(SqlConnection sqlConnection, string tblName, CustomerRecord customer)
{
// Test INPUT parameter on an encrypted parameter
using (SqlCommand sqlCommand = new SqlCommand($"SELECT CustomerId, FirstName, LastName FROM [{tblName}] WHERE FirstName = @firstName",
sqlConnection))
{
SqlParameter customerFirstParam = sqlCommand.Parameters.AddWithValue(@"firstName", @"Microsoft");
customerFirstParam.Direction = System.Data.ParameterDirection.Input;
customerFirstParam.ForceColumnEncryption = true;

using (SqlDataReader sqlDataReader = sqlCommand.ExecuteReader())
{
ValidateResultSet(sqlDataReader);
}
}
}

private static void ValidateResultSet(SqlDataReader sqlDataReader)
{
Console.WriteLine(" * Row available: " + sqlDataReader.HasRows);

while (sqlDataReader.Read())
{
if (sqlDataReader.GetInt32(0) == 1)
{
Console.WriteLine(" * Employee Id received as sent: " + sqlDataReader.GetInt32(0));
}
else
{
Console.WriteLine("Employee Id didn't match");
}

if (sqlDataReader.GetString(1) == @"Microsoft")
{
Console.WriteLine(" * Employee Firstname received as sent: " + sqlDataReader.GetString(1));
}
else
{
Console.WriteLine("Employee FirstName didn't match.");
}

if (sqlDataReader.GetString(2) == @"Corporation")
{
Console.WriteLine(" * Employee LastName received as sent: " + sqlDataReader.GetString(2));
}
else
{
Console.WriteLine("Employee LastName didn't match.");
}
}
}

private static void dropObjects(SqlConnection sqlConnection, string cmkName, string cekName, string tblName)
{
using (SqlCommand cmd = sqlConnection.CreateCommand())
{
cmd.CommandText = $@"IF EXISTS (select * from sys.objects where name = '{tblName}') BEGIN DROP TABLE [{tblName}] END";
cmd.ExecuteNonQuery();
cmd.CommandText = $@"IF EXISTS (select * from sys.column_encryption_keys where name = '{cekName}') BEGIN DROP COLUMN ENCRYPTION KEY [{cekName}] END";
cmd.ExecuteNonQuery();
cmd.CommandText = $@"IF EXISTS (select * from sys.column_master_keys where name = '{cmkName}') BEGIN DROP COLUMN MASTER KEY [{cmkName}] END";
cmd.ExecuteNonQuery();
}
}

private class CustomerRecord
{
internal int Id { get; set; }
internal string FirstName { get; set; }
internal string LastName { get; set; }

public CustomerRecord(int id, string fName, string lName)
{
Id = id;
FirstName = fName;
LastName = lName;
}
}
}
}
//</Snippet1>
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ internal void AddKey(string keyIdentifierUri)
{
if (TheKeyHasNotBeenCached(keyIdentifierUri))
{
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName);
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion);
CreateKeyClient(vaultUri);
FetchKey(vaultUri, keyName, keyIdentifierUri);
FetchKey(vaultUri, keyName, keyVersion, keyIdentifierUri);
}

bool TheKeyHasNotBeenCached(string k) => !_keyDictionary.ContainsKey(k) && !_keyFetchTaskDictionary.ContainsKey(k);
Expand Down Expand Up @@ -151,10 +151,11 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri)
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="keyName">The name of the Azure Key Vault key</param>
/// <param name="keyVersion">The version of the Azure Key Vault key</param>
/// <param name="keyResourceUri">The Azure Key Vault key identifier</param>
private void FetchKey(Uri vaultUri, string keyName, string keyResourceUri)
private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string keyResourceUri)
{
Task<Azure.Response<KeyVaultKey>> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName);
Task<Azure.Response<KeyVaultKey>> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion);
_keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, fetchKeyTask, (k, v) => fetchKeyTask);

fetchKeyTask
Expand All @@ -169,11 +170,12 @@ private void FetchKey(Uri vaultUri, string keyName, string keyResourceUri)
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="keyName">Then name of the key</param>
/// <param name="keyVersion">Then version of the key</param>
/// <returns></returns>
private Task<Azure.Response<KeyVaultKey>> FetchKeyFromKeyVault(Uri vaultUri, string keyName)
private Task<Azure.Response<KeyVaultKey>> FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion)
{
_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient);
return keyClient.GetKeyAsync(keyName);
return keyClient.GetKeyAsync(keyName, keyVersion);
}

/// <summary>
Expand Down Expand Up @@ -209,11 +211,13 @@ private void CreateKeyClient(Uri vaultUri)
/// <param name="masterKeyPath">The Azure Key Vault key identifier</param>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="masterKeyName">The name of the key</param>
private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName)
/// <param name="masterKeyVersion">The version of the key</param>
private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
{
Uri masterKeyPathUri = new Uri(masterKeyPath);
vaultUri = new Uri(masterKeyPathUri.GetLeftPart(UriPartial.Authority));
masterKeyName = masterKeyPathUri.Segments[2];
masterKeyVersion = masterKeyPathUri.Segments.Length > 3 ? masterKeyPathUri.Segments[3] : null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
/// <format type="text/markdown"><![CDATA[
/// ## Remarks
///
/// **SqlColumnEncryptionAzureKeyVaultProvider** is implemented for Microsoft.Data.SqlClient driver and supports .NET Framework 4.6+ and .NET Core 2.1+.
/// **SqlColumnEncryptionAzureKeyVaultProvider** is implemented for Microsoft.Data.SqlClient driver and supports .NET Framework 4.6.1+ and .NET Core 2.1+.
/// The provider name identifier for this implementation is "AZURE_KEY_VAULT" and it is not registered in driver by default.
/// Client applications must call <xref=Microsoft.Data.SqlClient.SqlConnection.RegisterColumnEncryptionKeyStoreProviders> API only once in the lifetime of driver to register this custom provider by implementing a custom Authentication Callback mechanism.
///
Expand Down Expand Up @@ -117,10 +117,10 @@ public SqlColumnEncryptionAzureKeyVaultProvider(TokenCredential tokenCredential,
#region Public methods

/// <summary>
/// Uses an asymmetric key identified by the key path to sign the masterkey metadata consisting of (masterKeyPath, allowEnclaveComputations bit, providerName).
/// Uses an asymmetric key identified by the key path to sign the master key metadata consisting of (masterKeyPath, allowEnclaveComputations bit, providerName).
/// </summary>
/// <param name="masterKeyPath">Complete path of an asymmetric key. Path format is specific to a key store provider.</param>
/// <param name="allowEnclaveComputations">Boolean indicating whether this key can be sent to trusted enclave</param>
/// <param name="allowEnclaveComputations">Boolean indicating whether this key can be sent to a trusted enclave</param>
/// <returns>Encrypted column encryption key</returns>
public override byte[] SignColumnMasterKeyMetadata(string masterKeyPath, bool allowEnclaveComputations)
{
Expand All @@ -133,7 +133,7 @@ public override byte[] SignColumnMasterKeyMetadata(string masterKeyPath, bool al
}

/// <summary>
/// Uses an asymmetric key identified by the key path to verify the masterkey metadata consisting of (masterKeyPath, allowEnclaveComputations bit, providerName).
/// Uses an asymmetric key identified by the key path to verify the master key metadata consisting of (masterKeyPath, allowEnclaveComputations bit, providerName).
/// </summary>
/// <param name="masterKeyPath">Complete path of an asymmetric key. Path format is specific to a key store provider.</param>
/// <param name="allowEnclaveComputations">Boolean indicating whether this key can be sent to trusted enclave</param>
Expand All @@ -153,7 +153,7 @@ public override bool VerifyColumnMasterKeyMetadata(string masterKeyPath, bool al
/// This function uses the asymmetric key specified by the key path
/// and decrypts an encrypted CEK with RSA encryption algorithm.
/// </summary>
/// <param name="masterKeyPath">Complete path of an asymmetric key in AKV</param>
/// <param name="masterKeyPath">Complete path of an asymmetric key in Azure Key Vault</param>
/// <param name="encryptionAlgorithm">Asymmetric Key Encryption Algorithm</param>
/// <param name="encryptedColumnEncryptionKey">Encrypted Column Encryption Key</param>
/// <returns>Plain text column encryption key</returns>
Expand Down Expand Up @@ -234,7 +234,7 @@ public override byte[] DecryptColumnEncryptionKey(string masterKeyPath, string e
/// This function uses the asymmetric key specified by the key path
/// and encrypts CEK with RSA encryption algorithm.
/// </summary>
/// <param name="masterKeyPath">Complete path of an asymmetric key in AKV</param>
/// <param name="masterKeyPath">Complete path of an asymmetric key in Azure Key Vault</param>
/// <param name="encryptionAlgorithm">Asymmetric Key Encryption Algorithm</param>
/// <param name="columnEncryptionKey">Plain text column encryption key</param>
/// <returns>Encrypted column encryption key</returns>
Expand All @@ -253,7 +253,7 @@ public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string e

// Construct the encryptedColumnEncryptionKey
// Format is
// s_firstVersion + keyPathLength + ciphertextLength + ciphertext + keyPath + signature
// s_firstVersion + keyPathLength + ciphertextLength + keyPath + ciphertext + signature

// Get the Unicode encoded bytes of cultureinvariant lower case masterKeyPath
byte[] masterKeyPathBytes = Encoding.Unicode.GetBytes(masterKeyPath.ToLowerInvariant());
Expand Down Expand Up @@ -310,8 +310,7 @@ internal void ValidateNonEmptyAKVPath(string masterKeyPath, bool isSystemOp)
throw new ArgumentException(errorMessage, Constants.AeParamMasterKeyPath);
}


if (!Uri.TryCreate(masterKeyPath, UriKind.Absolute, out Uri parsedUri))
if (!Uri.TryCreate(masterKeyPath, UriKind.Absolute, out Uri parsedUri) || parsedUri.Segments.Length < 3)
{
// Return an error indicating that the AKV url is invalid.
throw new ArgumentException(string.Format(CultureInfo.InvariantCulture, Strings.InvalidAkvUrlTemplate, masterKeyPath), Constants.AeParamMasterKeyPath);
Expand Down
Loading

0 comments on commit 52dc29d

Please sign in to comment.