Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Light up core ASCII.CaseConversion methods with Vector256/Vector512 code paths #88923

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ private static unsafe nuint ChangeCase<TFrom, TTo, TCasing>(TFrom* pSrc, TTo* pD
bool conversionIsWidthPreserving = typeof(TFrom) == typeof(TTo); // JIT turns this into a const
bool conversionIsToUpper = (typeof(TCasing) == typeof(ToUpperConversion)); // JIT turns this into a const
uint numInputElementsToConsumeEachVectorizedLoopIteration = (uint)(sizeof(Vector128<byte>) / sizeof(TFrom)); // JIT turns this into a const
uint numInputElementsToConsumeEachVectorizedLoopIteration_256 = (uint)(sizeof(Vector256<byte>) / sizeof(TFrom)); // JIT turns this into a const
uint numInputElementsToConsumeEachVectorizedLoopIteration_512 = (uint)(sizeof(Vector512<byte>) / sizeof(TFrom)); // JIT turns this into a const

nuint i = 0;

Expand All @@ -231,9 +233,154 @@ private static unsafe nuint ChangeCase<TFrom, TTo, TCasing>(TFrom* pSrc, TTo* pD
goto DrainRemaining;
}

// Process the input as a series of 128-bit blocks.
// Process the input as a series of 512-bit blocks.
if (Vector512.IsHardwareAccelerated && elementCount >= numInputElementsToConsumeEachVectorizedLoopIteration_512)
{
Vector512<TFrom> srcVector = Vector512.LoadUnsafe(ref *pSrc);
if (VectorContainsNonAsciiChar(srcVector))
{
goto Drain64;
}

// Now find matching characters and perform case conversion.
// Basically, the (A <= value && value <= Z) check is converted to:
// (value - CONST) <= (Z - A), but using signed instead of unsigned arithmetic.

TFrom SourceSignedMinValue = TFrom.CreateTruncating(1 << (8 * sizeof(TFrom) - 1));
Vector512<TFrom> subtractionVector = Vector512.Create(conversionIsToUpper ? (SourceSignedMinValue + TFrom.CreateTruncating('a')) : (SourceSignedMinValue + TFrom.CreateTruncating('A')));
Vector512<TFrom> comparisionVector = Vector512.Create(SourceSignedMinValue + TFrom.CreateTruncating(26 /* A..Z or a..z */));
Vector512<TFrom> caseConversionVector = Vector512.Create(TFrom.CreateTruncating(0x20)); // works both directions

Vector512<TFrom> matches = SignedLessThan((srcVector - subtractionVector), comparisionVector);
srcVector ^= (matches & caseConversionVector);

// Now write to the destination.

ChangeWidthAndWriteTo(srcVector, pDest, 0);

// Now that the first conversion is out of the way, calculate how
// many elements we should skip in order to have future writes be
// aligned.

uint expectedWriteAlignment_512 = numInputElementsToConsumeEachVectorizedLoopIteration_512 * (uint)sizeof(TTo); // JIT turns this into a const
i = numInputElementsToConsumeEachVectorizedLoopIteration_512 - ((uint)pDest % expectedWriteAlignment_512) / (uint)sizeof(TTo);
Debug.Assert((nuint)(&pDest[i]) % expectedWriteAlignment_512 == 0, "Destination buffer wasn't properly aligned!");

while (true)
{
Debug.Assert(i <= elementCount, "We overran a buffer somewhere.");

if ((elementCount - i) < numInputElementsToConsumeEachVectorizedLoopIteration_512)
{
// If we're about to enter the final iteration of the loop, back up so that
// we can read one unaligned block. If we've already consumed all the data,
// jump straight to the end.

if (i == elementCount)
{
goto Return;
}

i = elementCount - numInputElementsToConsumeEachVectorizedLoopIteration_512;
}

// Unaligned read & check for non-ASCII data.

srcVector = Vector512.LoadUnsafe(ref *pSrc, i);
if (VectorContainsNonAsciiChar(srcVector))
{
goto Drain64;
}

// Now find matching characters and perform case conversion.

matches = SignedLessThan((srcVector - subtractionVector), comparisionVector);
srcVector ^= (matches & caseConversionVector);

// Now write to the destination.
// We expect this write to be aligned except for the last run through the loop.

if (Vector128.IsHardwareAccelerated && elementCount >= numInputElementsToConsumeEachVectorizedLoopIteration)
ChangeWidthAndWriteTo(srcVector, pDest, i);
i += numInputElementsToConsumeEachVectorizedLoopIteration_512;
}
}
// Process the input as a series of 256-bit blocks.
else if (Vector256.IsHardwareAccelerated && elementCount >= numInputElementsToConsumeEachVectorizedLoopIteration_256)
{
// Unaligned read and check for non-ASCII data.
Vector256<TFrom> srcVector = Vector256.LoadUnsafe(ref *pSrc);
if (VectorContainsNonAsciiChar(srcVector))
{
goto Drain64;
}

// Now find matching characters and perform case conversion.
// Basically, the (A <= value && value <= Z) check is converted to:
// (value - CONST) <= (Z - A), but using signed instead of unsigned arithmetic.

TFrom SourceSignedMinValue = TFrom.CreateTruncating(1 << (8 * sizeof(TFrom) - 1));
Vector256<TFrom> subtractionVector = Vector256.Create(conversionIsToUpper ? (SourceSignedMinValue + TFrom.CreateTruncating('a')) : (SourceSignedMinValue + TFrom.CreateTruncating('A')));
Vector256<TFrom> comparisionVector = Vector256.Create(SourceSignedMinValue + TFrom.CreateTruncating(26 /* A..Z or a..z */));
Vector256<TFrom> caseConversionVector = Vector256.Create(TFrom.CreateTruncating(0x20)); // works both directions

Vector256<TFrom> matches = SignedLessThan((srcVector - subtractionVector), comparisionVector);
srcVector ^= (matches & caseConversionVector);

// Now write to the destination.

ChangeWidthAndWriteTo(srcVector, pDest, 0);

// Now that the first conversion is out of the way, calculate how
// many elements we should skip in order to have future writes be
// aligned.

uint expectedWriteAlignment_256 = numInputElementsToConsumeEachVectorizedLoopIteration_256 * (uint)sizeof(TTo); // JIT turns this into a const
i = numInputElementsToConsumeEachVectorizedLoopIteration_256 - ((uint)pDest % expectedWriteAlignment_256) / (uint)sizeof(TTo);
Debug.Assert((nuint)(&pDest[i]) % expectedWriteAlignment_256 == 0, "Destination buffer wasn't properly aligned!");

// Future iterations of this loop will be aligned,
// except for the last iteration.

while (true)
{
Debug.Assert(i <= elementCount, "We overran a buffer somewhere.");

if ((elementCount - i) < numInputElementsToConsumeEachVectorizedLoopIteration_256)
{
// If we're about to enter the final iteration of the loop, back up so that
// we can read one unaligned block. If we've already consumed all the data,
// jump straight to the end.

if (i == elementCount)
{
goto Return;
}

i = elementCount - numInputElementsToConsumeEachVectorizedLoopIteration_256;
}

// Unaligned read & check for non-ASCII data.

srcVector = Vector256.LoadUnsafe(ref *pSrc, i);
if (VectorContainsNonAsciiChar(srcVector))
{
goto Drain64;
}

// Now find matching characters and perform case conversion.

matches = SignedLessThan((srcVector - subtractionVector), comparisionVector);
srcVector ^= (matches & caseConversionVector);

// Now write to the destination.
// We expect this write to be aligned except for the last run through the loop.

ChangeWidthAndWriteTo(srcVector, pDest, i);
i += numInputElementsToConsumeEachVectorizedLoopIteration_256;
}
}
// Process the input as a series of 128-bit blocks.
else if (Vector128.IsHardwareAccelerated && elementCount >= numInputElementsToConsumeEachVectorizedLoopIteration)
{
// Unaligned read and check for non-ASCII data.

Expand Down Expand Up @@ -500,6 +647,71 @@ private static unsafe void ChangeWidthAndWriteTo<TFrom, TTo>(Vector128<TFrom> ve
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void ChangeWidthAndWriteTo<TFrom, TTo>(Vector256<TFrom> vector, TTo* pDest, nuint elementOffset)
where TFrom : unmanaged
where TTo : unmanaged
{
if (sizeof(TFrom) == sizeof(TTo))
{
// no width change needed
Vector256.StoreUnsafe(vector.As<TFrom, TTo>(), ref *pDest, elementOffset);
}
else if (sizeof(TFrom) == 1 && sizeof(TTo) == 2)
{
if (Vector512.IsHardwareAccelerated)
{
Vector512<ushort> wide = Vector512.WidenLower(vector.AsByte().ToVector512Unsafe());
Vector512.StoreUnsafe(wide, ref *(ushort*)pDest, elementOffset);
}
else
{
Vector256.StoreUnsafe(Vector256.WidenLower(vector.AsByte()), ref *(ushort*)pDest, elementOffset);
Vector256.StoreUnsafe(Vector256.WidenUpper(vector.AsByte()), ref *(ushort*)pDest, elementOffset + 16);
}
}
else if (sizeof(TFrom) == 2 && sizeof(TTo) == 1)
{
// narrowing operation required, we know data is all-ASCII so use extract helper
Vector256<byte> narrow = Vector256.Narrow(vector.AsUInt16(), vector.AsUInt16());
narrow.GetLower().StoreUnsafe(ref *(byte*)pDest, elementOffset);
}
else
{
Debug.Fail("Unknown types.");
throw new NotSupportedException();
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void ChangeWidthAndWriteTo<TFrom, TTo>(Vector512<TFrom> vector, TTo* pDest, nuint elementOffset)
where TFrom : unmanaged
where TTo : unmanaged
{
if (sizeof(TFrom) == sizeof(TTo))
{
// no width change needed
Vector512.StoreUnsafe(vector.As<TFrom, TTo>(), ref *pDest, elementOffset);
}
else if (sizeof(TFrom) == 1 && sizeof(TTo) == 2)
{
// widening operation required
Vector512.StoreUnsafe(Vector512.WidenLower(vector.AsByte()), ref *(ushort*)pDest, elementOffset);
Vector512.StoreUnsafe(Vector512.WidenUpper(vector.AsByte()), ref *(ushort*)pDest, elementOffset + 32);
}
else if (sizeof(TFrom) == 2 && sizeof(TTo) == 1)
{
// narrowing operation required, we know data is all-ASCII so use extract helper
Vector512<byte> narrow = Vector512.Narrow(vector.AsUInt16(), vector.AsUInt16());
narrow.GetLower().StoreUnsafe(ref *(byte*)pDest, elementOffset);
}
else
{
Debug.Fail("Unknown types.");
throw new NotSupportedException();
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector128<T> SignedLessThan<T>(Vector128<T> left, Vector128<T> right)
where T : unmanaged
Expand All @@ -518,6 +730,42 @@ private static unsafe Vector128<T> SignedLessThan<T>(Vector128<T> left, Vector12
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector256<T> SignedLessThan<T>(Vector256<T> left, Vector256<T> right)
where T : unmanaged
{
if (sizeof(T) == 1)
{
return Vector256.LessThan(left.AsSByte(), right.AsSByte()).As<sbyte, T>();
}
else if (sizeof(T) == 2)
{
return Vector256.LessThan(left.AsInt16(), right.AsInt16()).As<short, T>();
}
else
{
throw new NotSupportedException();
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector512<T> SignedLessThan<T>(Vector512<T> left, Vector512<T> right)
where T : unmanaged
{
if (sizeof(T) == 1)
{
return Vector512.LessThan(left.AsSByte(), right.AsSByte()).As<sbyte, T>();
}
else if (sizeof(T) == 2)
{
return Vector512.LessThan(left.AsInt16(), right.AsInt16()).As<short, T>();
}
else
{
throw new NotSupportedException();
}
}

private struct ToUpperConversion { }
private struct ToLowerConversion { }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,72 @@ private static bool VectorContainsNonAsciiChar<T>(Vector128<T> vector)
: VectorContainsNonAsciiChar(vector.AsUInt16());
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar(Vector256<byte> asciiVector)
{
// max ASCII character is 0b_0111_1111, so the most significant bit (0x80) tells whether it contains non ascii

// prefer architecture specific intrinsic as they offer better perf
return asciiVector.ExtractMostSignificantBits() != 0;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar(Vector256<ushort> utf16Vector)
{
if (Avx.IsSupported)
{
Vector256<ushort> asciiMaskForTestZ = Vector256.Create((ushort)0xFF80);
return !Avx.TestZ(utf16Vector.AsInt16(), asciiMaskForTestZ.AsInt16());
}
else
{
const ushort asciiMask = ushort.MaxValue - 127; // 0xFF80
Vector256<ushort> zeroIsAscii = utf16Vector & Vector256.Create(asciiMask);
// If a non-ASCII bit is set in any WORD of the vector, we have seen non-ASCII data.
return zeroIsAscii != Vector256<ushort>.Zero;
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar<T>(Vector256<T> vector)
where T : unmanaged
{
Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

return typeof(T) == typeof(byte)
? VectorContainsNonAsciiChar(vector.AsByte())
: VectorContainsNonAsciiChar(vector.AsUInt16());
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar(Vector512<byte> asciiVector)
{
// max ASCII character is 0b_0111_1111, so the most significant bit (0x80) tells whether it contains non ascii

// prefer architecture specific intrinsic as they offer better perf
return asciiVector.ExtractMostSignificantBits() != 0;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar(Vector512<ushort> utf16Vector)
{
const ushort asciiMask = ushort.MaxValue - 127; // 0xFF80
Vector512<ushort> zeroIsAscii = utf16Vector & Vector512.Create(asciiMask);
// If a non-ASCII bit is set in any WORD of the vector, we have seen non-ASCII data.
return zeroIsAscii != Vector512<ushort>.Zero;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar<T>(Vector512<T> vector)
where T : unmanaged
{
Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

return typeof(T) == typeof(byte)
? VectorContainsNonAsciiChar(vector.AsByte())
: VectorContainsNonAsciiChar(vector.AsUInt16());
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool AllCharsInVectorAreAscii<T>(Vector128<T> vector)
where T : unmanaged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ public static IEnumerable<object[]> MultipleValidCharacterConversion_Arguments
yield return new object[] { "\0xyz\0", "\0xyz\0", "\0XYZ\0" };
yield return new object[] { "\0XYZ\0", "\0xyz\0", "\0XYZ\0" };
yield return new object[] { "AbCdEFgHIJkLmNoPQRStUVwXyZ", "abcdefghijklmnopqrstuvwxyz", "ABCDEFGHIJKLMNOPQRSTUVWXYZ" };
yield return new object[] { "AbCdEFgHIJkLmNoPQRStUVwXyZasdWSq", "abcdefghijklmnopqrstuvwxyzasdwsq", "ABCDEFGHIJKLMNOPQRSTUVWXYZASDWSQ" }; // Test case for Vector256 path.
yield return new object[] { "AbCdEFgHIJkLmNoPQRStUVwXyZasdWSqAbCdEFgHIJkLmNoPQRStUVwXyZasdWSq",
"abcdefghijklmnopqrstuvwxyzasdwsqabcdefghijklmnopqrstuvwxyzasdwsq",
"ABCDEFGHIJKLMNOPQRSTUVWXYZASDWSQABCDEFGHIJKLMNOPQRSTUVWXYZASDWSQ" }; // Test case for Vector512 path.

// exercise all possible code paths
for (int i = 1; i <= MaxValidAsciiChar; i++)
Expand All @@ -180,7 +184,6 @@ public static void MultipleValidCharacterConversion(string sourceChars, string e
{
Assert.Equal(sourceChars.Length, expectedLowerChars.Length);
Assert.Equal(expectedLowerChars.Length, expectedUpperChars.Length);

byte[] sourceBytes = Encoding.ASCII.GetBytes(sourceChars);
byte[] expectedLowerBytes = Encoding.ASCII.GetBytes(expectedLowerChars);
byte[] expectedUpperBytes = Encoding.ASCII.GetBytes(expectedUpperChars);
Expand Down
Loading