Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manually depad RSAES-PKCS1 on Apple OSes #97738

Merged
merged 7 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
Expand Down Expand Up @@ -69,8 +70,8 @@ private static partial int RsaDecryptOaep(
out SafeCFDataHandle pEncryptedOut,
out SafeCFErrorHandle pErrorOut);

[LibraryImport(Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_RsaDecryptPkcs")]
private static partial int RsaDecryptPkcs(
[LibraryImport(Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_RsaDecryptRaw")]
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
private static partial int RsaDecryptRaw(
SafeSecKeyRefHandle publicKey,
ReadOnlySpan<byte> pbData,
int cbData,
Expand Down Expand Up @@ -166,17 +167,40 @@ internal static byte[] RsaDecrypt(
byte[] data,
RSAEncryptionPadding padding)
{
if (padding == RSAEncryptionPadding.Pkcs1)
{
byte[] padded = ExecuteTransform(
data,
(ReadOnlySpan<byte> source, out SafeCFDataHandle decrypted, out SafeCFErrorHandle error) =>
RsaDecryptRaw(privateKey, source, source.Length, out decrypted, out error));

byte[] depad = CryptoPool.Rent(padded.Length);
OperationStatus status = RsaPaddingProcessor.DepadPkcs1Encryption(padded, depad, out int written);
byte[]? ret = null;

if (status == OperationStatus.Done)
{
ret = depad.AsSpan(0, written).ToArray();
}

// Clear the whole thing, especially on failure.
CryptoPool.Return(depad);
CryptographicOperations.ZeroMemory(padded);

if (ret is null)
{
throw new CryptographicException(SR.Cryptography_InvalidPadding);
}

return ret;
}

Debug.Assert(padding.Mode == RSAEncryptionPaddingMode.Oaep);

return ExecuteTransform(
data,
(ReadOnlySpan<byte> source, out SafeCFDataHandle decrypted, out SafeCFErrorHandle error) =>
{
if (padding == RSAEncryptionPadding.Pkcs1)
{
return RsaDecryptPkcs(privateKey, source, source.Length, out decrypted, out error);
}

Debug.Assert(padding.Mode == RSAEncryptionPaddingMode.Oaep);

return RsaDecryptOaep(
privateKey,
source,
Expand All @@ -195,14 +219,63 @@ internal static bool TryRsaDecrypt(
out int bytesWritten)
{
Debug.Assert(padding.Mode == RSAEncryptionPaddingMode.Pkcs1 || padding.Mode == RSAEncryptionPaddingMode.Oaep);

if (padding.Mode == RSAEncryptionPaddingMode.Pkcs1)
{
byte[] padded = CryptoPool.Rent(source.Length);
byte[] depad = CryptoPool.Rent(source.Length);

bool processed = TryExecuteTransform(
source,
padded,
out int paddedLength,
(ReadOnlySpan<byte> innerSource, out SafeCFDataHandle outputHandle, out SafeCFErrorHandle errorHandle) =>
RsaDecryptRaw(privateKey, innerSource, innerSource.Length, out outputHandle, out errorHandle));

Debug.Assert(
processed,
"TryExecuteTransform should always return true for a large enough buffer.");

OperationStatus status = OperationStatus.InvalidData;
int depaddedLength = 0;

if (processed)
{
status = RsaPaddingProcessor.DepadPkcs1Encryption(
new ReadOnlySpan<byte>(padded, 0, paddedLength),
depad,
out depaddedLength);
}

CryptoPool.Return(padded);

if (status == OperationStatus.Done)
{
if (depaddedLength <= destination.Length)
{
depad.AsSpan(0, depaddedLength).CopyTo(destination);
CryptoPool.Return(depad);
bytesWritten = depaddedLength;
return true;
}

CryptoPool.Return(depad);
bytesWritten = 0;
return false;
}

CryptoPool.Return(depad);
Debug.Assert(status == OperationStatus.InvalidData);
throw new CryptographicException(SR.Cryptography_InvalidPadding);
}

return TryExecuteTransform(
source,
destination,
out bytesWritten,
delegate (ReadOnlySpan<byte> innerSource, out SafeCFDataHandle outputHandle, out SafeCFErrorHandle errorHandle)
{
return padding.Mode == RSAEncryptionPaddingMode.Pkcs1 ?
RsaDecryptPkcs(privateKey, innerSource, innerSource.Length, out outputHandle, out errorHandle) :
return
RsaDecryptOaep(privateKey, innerSource, innerSource.Length, PalAlgorithmFromAlgorithmName(padding.OaepHashAlgorithm), out outputHandle, out errorHandle);
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Concurrent;
using System.Diagnostics;
Expand Down Expand Up @@ -142,6 +143,109 @@ internal static void PadPkcs1Encryption(
source.CopyTo(mInEM);
}

internal static OperationStatus DepadPkcs1Encryption(
ReadOnlySpan<byte> source,
Span<byte> destination,
out int bytesWritten)
{
int primitive = DepadPkcs1Encryption(source);
int primitiveSign = SignStretch(primitive);

// Primitive is a positive length, or ~length to indicate
// an error, so flip ~length to length if the high bit is set.
int len = Choose(primitiveSign, ~primitive, primitive);
int spaceRemain = destination.Length - len;
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
int spaceRemainSign = SignStretch(spaceRemain);

// len = clampHigh(len, destination.Length);
len = Choose(spaceRemainSign, destination.Length, len);

// ret = spaceRemain < 0 ? DestinationTooSmall : Done
int ret = Choose(
spaceRemainSign,
(int)OperationStatus.DestinationTooSmall,
(int)OperationStatus.Done);

// ret = primitive < 0 ? InvalidData : ret;
ret = Choose(primitiveSign, (int)OperationStatus.InvalidData, ret);

// Write some number of bytes, regardless of the final return.
source[^len..].CopyTo(destination);

// bytesWritten = ret == Done ? len : 0;
bytesWritten = Choose(CheckZero(ret), len, 0);
return (OperationStatus)ret;
}

private static int DepadPkcs1Encryption(ReadOnlySpan<byte> source)
{
Debug.Assert(source.Length > 11);
ReadOnlySpan<byte> afterPadding = source.Slice(10);
ReadOnlySpan<byte> noZeros = source.Slice(2, 8);

// Find the first zero in noZeros, or -1 for no zeros.
int zeroPos = BlindFindFirstZero(noZeros);

// If zeroPos is negative, valid is -1, otherwise 0.
int valid = SignStretch(zeroPos);

// If there are no zeros in afterPadding then zeroPos is negative,
// so negating the sign stretch is 0, which makes hasPos 0.
// If there -was- a zero, sign stretching is 0, so negating it makes hasPos -1.
zeroPos = BlindFindFirstZero(afterPadding);
int hasLen = ~SignStretch(zeroPos);
valid &= hasLen;

// Check that the first two bytes are { 00 02 }
valid &= CheckZero(source[0] | (source[1] ^ 0x02));

int lenIfGood = afterPadding.Length - zeroPos - 1;
// If there were no zeros, use the full after-min-padding segment.
int lenIfBad = ~Choose(hasLen, lenIfGood, source.Length - 11);

Debug.Assert(lenIfBad < 0);
return Choose(valid, lenIfGood, lenIfBad);
}

private static int BlindFindFirstZero(ReadOnlySpan<byte> source)
{
// Any vectorization of this routine needs to use non-early termination,
// and instructions that do not vary their completion time on the input.

int pos = -1;

for (int i = source.Length - 1; i >= 0; i--)
{
// pos = source[i] == 0 ? i : pos;
int local = CheckZero(source[i]);
pos = Choose(local, i, pos);
}

return pos;
}

private static int SignStretch(int value)
{
return value >> 31;
}

private static int Choose(int selector, int yes, int no)
{
Debug.Assert((selector | (selector - 1)) == -1);
return (selector & yes) | (~selector & no);
}

private static int CheckZero(int value)
{
// For zero, ~value and value-1 are both all bits set (negative).
// For positive values, ~value is negative and value-1 is positive.
// For negative values except MinValue, ~value is positive and value-1 is negative.
// For MinValue, ~value is positive and value-1 is also positive.
// All together, the only thing that has negative & negative is 0, so stretch the sign bit.
int mask = ~value & (value - 1);
return SignStretch(mask);
}

internal static void PadPkcs1Signature(
HashAlgorithmName hashAlgorithmName,
ReadOnlySpan<byte> source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Numerics;
using Test.Cryptography;
using Microsoft.DotNet.XUnitExtensions;
using Xunit;
Expand Down Expand Up @@ -736,6 +738,119 @@ public void Decrypt_Pkcs1_ErrorsForInvalidPadding(byte[] data)
}
}

[Fact]
public void Decrypt_Pkcs1_BadPadding()
{
if ((PlatformDetection.IsWindows && !PlatformDetection.IsWindows10Version2004OrGreater))
{
return;
}

RSAParameters keyParams = TestData.RSA2048Params;
BigInteger e = new BigInteger(keyParams.Exponent, true, true);
BigInteger n = new BigInteger(keyParams.Modulus, true, true);
byte[] buf = new byte[keyParams.Modulus.Length];
byte[] c = new byte[buf.Length];

buf[1] = 2;
buf.AsSpan(2).Fill(1);

ref byte afterMinPadding = ref buf[10];
ref byte lastByte = ref buf[^1];
afterMinPadding = 0;

using (RSA rsa = RSAFactory.Create(keyParams))
{
RawEncrypt(buf, e, n, c);
// Assert.NoThrow, check that manual padding is coherent
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// All RSA encryption schemes start with 00, so pick any other number.
//
// If buf > modulus then encrypt should fail, so this
// is the largest legal-but-invalid value to test.
buf[0] = keyParams.Modulus[0];
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Check again with a zero length payload
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Back to valid padding
buf[0] = 0;
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// This is (sort of) legal for PKCS1 signatures, but not decryption.
buf[1] = 1;
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// No RSA PKCS1 padding scheme starts with 00 FF.
buf[1] = 255;
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Check again with a zero length payload
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Back to valid padding
buf[1] = 2;
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// Try a zero in every possible required padding position
for (int i = 2; i < 10; i++)
{
buf[i] = 0;

RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// It used to be 1, now it's 2, still not zero.
buf[i] = 2;
}

// Back to valid padding
RawEncrypt(buf, e, n, c);
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// Make it such that
// "there is no octet with hexadecimal value 0x00 to separate PS from M"
// (RFC 3447 sec 7.2.2, rule 3, third clause)
buf.AsSpan(10).Fill(3);
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Every possible problem, for good measure.
buf[0] = 2;
buf[1] = 0;
buf[4] = 0;
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));
}

static void RawEncrypt(ReadOnlySpan<byte> source, BigInteger e, BigInteger n, Span<byte> destination)
{
BigInteger m = new BigInteger(source, true, true);
BigInteger c = BigInteger.ModPow(m, e, n);
int shift = destination.Length - c.GetByteCount(true);
destination.Slice(0, shift).Clear();
bool wrote = c.TryWriteBytes(destination.Slice(shift), out int written, true, true);

if (!wrote || written + shift != destination.Length)
{
throw new UnreachableException();
}
}
}

public static IEnumerable<object[]> OaepPaddingModes
{
get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ static const Entry s_cryptoAppleNative[] =
DllImportEntry(AppleCryptoNative_RsaGenerateKey)
DllImportEntry(AppleCryptoNative_RsaDecryptOaep)
DllImportEntry(AppleCryptoNative_RsaDecryptPkcs)
DllImportEntry(AppleCryptoNative_RsaDecryptRaw)
DllImportEntry(AppleCryptoNative_RsaEncryptOaep)
DllImportEntry(AppleCryptoNative_RsaEncryptPkcs)
DllImportEntry(AppleCryptoNative_RsaSignaturePrimitive)
Expand Down
Loading
Loading