Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove linq #1949

Merged
merged 11 commits into from
Mar 24, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using System.Text;
using Azure.Core;
using Azure.Security.KeyVault.Keys.Cryptography;
Expand Down Expand Up @@ -229,14 +228,18 @@ byte[] DecryptEncryptionKey()
}

// Get ciphertext
byte[] cipherText = encryptedColumnEncryptionKey.Skip(currentIndex).Take(cipherTextLength).ToArray();
byte[] cipherText = new byte[cipherTextLength];
Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);

currentIndex += cipherTextLength;

// Get signature
byte[] signature = encryptedColumnEncryptionKey.Skip(currentIndex).Take(signatureLength).ToArray();
byte[] signature = new byte[signatureLength];
Buffer.BlockCopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength);

// Compute the message to validate the signature
byte[] message = encryptedColumnEncryptionKey.Take(encryptedColumnEncryptionKey.Length - signatureLength).ToArray();
byte[] message = new byte[encryptedColumnEncryptionKey.Length - signatureLength];
Buffer.BlockCopy(encryptedColumnEncryptionKey, 0, message, 0, encryptedColumnEncryptionKey.Length - signatureLength);

if (null == message)
{
Expand Down Expand Up @@ -294,7 +297,24 @@ public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string e

// Compute message
// SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext)
byte[] message = s_firstVersion.Concat(keyPathLength).Concat(cipherTextLength).Concat(masterKeyPathBytes).Concat(cipherText).ToArray();
int messageLength = s_firstVersion.Length + keyPathLength.Length + cipherTextLength.Length + masterKeyPathBytes.Length + cipherText.Length;
byte[] message = new byte[messageLength];
int position = 0;

Buffer.BlockCopy(s_firstVersion, 0, message, position, s_firstVersion.Length);
position += s_firstVersion.Length;

Buffer.BlockCopy(keyPathLength, 0, message, position, keyPathLength.Length);
position += keyPathLength.Length;

Buffer.BlockCopy(cipherTextLength, 0, message, position, cipherTextLength.Length);
position += cipherTextLength.Length;

Buffer.BlockCopy(masterKeyPathBytes, 0, message, position, masterKeyPathBytes.Length);
position += masterKeyPathBytes.Length;

Buffer.BlockCopy(cipherText, 0, message, position, cipherText.Length);
position += cipherText.Length;

// Sign the message
byte[] signature = KeyCryptographer.SignData(message, masterKeyPath);
Expand All @@ -306,7 +326,11 @@ public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string e

ValidateSignature(masterKeyPath, message, signature);

return message.Concat(signature).ToArray();
byte[] retval = new byte[message.Length + signature.Length];
Buffer.BlockCopy(message, 0, retval, 0, message.Length);
Buffer.BlockCopy(signature, 0, retval, message.Length, signature.Length);

return retval;
}

#endregion
Expand Down Expand Up @@ -345,7 +369,7 @@ internal void ValidateNonEmptyAKVPath(string masterKeyPath, bool isSystemOp)

// Return an error indicating that the AKV url is invalid.
AKVEventSource.Log.TryTraceEvent("Master Key Path could not be validated as it does not end with trusted endpoints: {0}", masterKeyPath);
throw ADP.InvalidAKVUrlTrustedEndpoints(masterKeyPath, string.Join(", ", TrustedEndPoints.ToArray()));
throw ADP.InvalidAKVUrlTrustedEndpoints(masterKeyPath, string.Join(", ", TrustedEndPoints));
}

private void ValidateSignature(string masterKeyPath, byte[] message, byte[] signature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Security.Cryptography;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
Expand All @@ -15,7 +14,7 @@ internal static class Validator
{
internal static void ValidateNotNull(object parameter, string name)
{
if (null == parameter)
if (parameter == null)
{
throw ADP.NullArgument(name);
}
Expand All @@ -31,9 +30,15 @@ internal static void ValidateNotEmpty(IList parameter, string name)

internal static void ValidateNotNullOrWhitespaceForEach(string[] parameters, string name)
{
if (parameters.Any(s => string.IsNullOrWhiteSpace(s)))
if (parameters != null && parameters.Length > 0)
{
throw ADP.NullOrWhitespaceForEach(name);
for (int index = 0; index < parameters.Length; index++)
{
if (string.IsNullOrWhiteSpace(parameters[index]))
{
throw ADP.NullOrWhitespaceForEach(name);
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.Loader;

Expand Down Expand Up @@ -69,7 +69,21 @@ private static Assembly AssemblyResolver(AssemblyName arg)
return fullPath == null ? null : AssemblyLoadContext.Default.LoadFromAssemblyPath(fullPath);
}

private static Type TypeResolver(Assembly arg1, string arg2, bool arg3) => arg1?.ExportedTypes.Single(t => t.FullName == arg2);
private static Type TypeResolver(Assembly arg1, string arg2, bool arg3)
{
IEnumerable<Type> types = arg1?.ExportedTypes;
if (types != null)
{
foreach (Type type in types)
{
if (type.FullName == arg2)
{
return type;
lcheunglci marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
throw new InvalidOperationException("Sequence contains no matching element");
}

/// <summary>
/// Load assemblies on request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Linq;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -83,12 +83,16 @@ public override void ReturnPacket(SNIPacket packet)
#if DEBUG
private string GetStackParts()
{
return string.Join(Environment.NewLine,
Environment.StackTrace
.Split(new string[] { Environment.NewLine }, StringSplitOptions.None)
.Skip(3) // trims off the common parts at the top of the stack so you can see what the actual caller was
.Take(7) // trims off most of the bottom of the stack because when running under xunit there's a lot of spam
);
// trims off the common parts at the top of the stack so you can see what the actual caller was
// trims off most of the bottom of the stack because when running under xunit there's a lot of spam
string[] parts = Environment.StackTrace.Split(new string[] { Environment.NewLine }, StringSplitOptions.None);
List<string> take = new List<string>(7);
for (int index = 3; take.Count < 7 && index < parts.Length; index++)
{
take.Add(parts[index]);
}

return string.Join(Environment.NewLine, take.ToArray());
}
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
Expand Down Expand Up @@ -192,52 +191,77 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re

IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname);
Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");

IPAddress[] ipv4Addresses = null;
IPAddress[] ipv6Addresses = null;
switch (ipPreference)
{
case SqlConnectionIPAddressPreference.IPv4First:
{
SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
SplitIPv4AndIPv6(ipv4Addresses, out ipv4Addresses, out ipv6Addresses);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SplitIPv4AndIPv6(ipv4Addresses, out ipv4Addresses, out ipv6Addresses);
SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses);

I'm assuming you're passing in the ipAddresses from line 192 instead of ipv4Addresses which is null on line 194; otherwise, it would always be Array.Empty<IPAddress>().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. It is concerning that the CI didn't fail on this error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I was wondering about that as well. It turns out since this code is only run on Linux and requires sudo/administrative privileges to get the correct environment, there's currently no test coverage for this section unfortunately.


SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
{
return response4.ResponsePacket;
}

SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
{
return response6.ResponsePacket;
}

// No responses so throw first error
if (response4 != null && response4.Error != null)
{
throw response4.Error;
}
else if (response6 != null && response6.Error != null)
{
throw response6.Error;
}

break;
}
case SqlConnectionIPAddressPreference.IPv6First:
{
SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
SplitIPv4AndIPv6(ipv4Addresses, out ipv4Addresses, out ipv6Addresses);
lcheunglci marked this conversation as resolved.
Show resolved Hide resolved

SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
{
return response6.ResponsePacket;
}

SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
{
return response4.ResponsePacket;
}

// No responses so throw first error
if (response6 != null && response6.Error != null)
{
throw response6.Error;
}
else if (response4 != null && response4.Error != null)
{
throw response4.Error;
}

break;
}
default:
{
SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel);
if (response != null && response.ResponsePacket != null)
{
return response.ResponsePacket;
}
else if (response != null && response.Error != null)
{
throw response.Error;
}

break;
}
Expand Down Expand Up @@ -372,5 +396,40 @@ internal static string SendBroadcastUDPRequest()
}
return response.ToString();
}

private static void SplitIPv4AndIPv6(IPAddress[] input, out IPAddress[] ipv4Addresses, out IPAddress[] ipv6Addresses)
{
ipv4Addresses = Array.Empty<IPAddress>();
ipv6Addresses = Array.Empty<IPAddress>();

if (input != null && input.Length > 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe revert this condition and have fast exit? Code above looks a bit unnatural with this condition

ipv4Addresses = Array.Empty<IPAddress>();
ipv6Addresses = Array.Empty<IPAddress>();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's like this because they're out parameters. I suppose it's a personal style choice but I always assign the default values for out parameters at the start of the method so that anywhere inside the body I can easily return without worrying about whether they're assigned or not.

{
List<IPAddress> v4 = new List<IPAddress>(1);
List<IPAddress> v6 = new List<IPAddress>(0);

for (int index = 0; index < input.Length; index++)
{
switch (input[index].AddressFamily)
{
case AddressFamily.InterNetwork:
v4.Add(input[index]);
break;
case AddressFamily.InterNetworkV6:
v6.Add(input[index]);
break;
}
}

if (v4.Count > 0)
{
ipv4Addresses = v4.ToArray();
}

if (v6.Count > 0)
{
ipv6Addresses = v6.ToArray();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,24 @@ static SqlAuthenticationProviderManager()
public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSection configSection = null)
{
var methodName = "Ctor";
_typeName = GetType().Name;
_providers = new ConcurrentDictionary<SqlAuthenticationMethod, SqlAuthenticationProvider>();
var authenticationsWithAppSpecifiedProvider = new HashSet<SqlAuthenticationMethod>();
_authenticationsWithAppSpecifiedProvider = authenticationsWithAppSpecifiedProvider;

if (configSection == null)
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "Neither SqlClientAuthenticationProviders nor SqlAuthenticationProviders configuration section found.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Neither SqlClientAuthenticationProviders nor SqlAuthenticationProviders configuration section found.");
return;
}

if (!string.IsNullOrEmpty(configSection.ApplicationClientId))
{
_applicationClientId = configSection.ApplicationClientId;
_sqlAuthLogger.LogInfo(_typeName, methodName, "Received user-defined Application Client Id");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Received user-defined Application Client Id");
}
else
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined Application Client Id found.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined Application Client Id found.");
}

// Create user-defined auth initializer, if any.
Expand All @@ -77,11 +76,11 @@ public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSe
{
throw SQL.CannotCreateSqlAuthInitializer(configSection.InitializerType, e);
}
_sqlAuthLogger.LogInfo(_typeName, methodName, "Created user-defined SqlAuthenticationInitializer.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Created user-defined SqlAuthenticationInitializer.");
}
else
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined SqlAuthenticationInitializer found.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined SqlAuthenticationInitializer found.");
}

// add user-defined providers, if any.
Expand All @@ -107,12 +106,12 @@ public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSe

_providers[authentication] = provider;
authenticationsWithAppSpecifiedProvider.Add(authentication);
_sqlAuthLogger.LogInfo(_typeName, methodName, string.Format("Added user-defined auth provider: {0} for authentication {1}.", providerSettings?.Type, authentication));
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, string.Format("Added user-defined auth provider: {0} for authentication {1}.", providerSettings?.Type, authentication));
}
}
else
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined auth providers.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined auth providers.");
}
}

Expand Down
Loading