Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement System.Buffers.Text.Base64.DecodeFromUtf8 for Arm64 #70336

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;

namespace System.Buffers.Text
{
// AVX2 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/avx2
// SSSE3 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/ssse3
// AdvSimd version based on https://github.com/aklomp/base64/blob/e516d769a2a432c08404f1981e73b431566057be/lib/arch/neon64

public static partial class Base64
{
Expand Down Expand Up @@ -81,6 +83,15 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan<byte> utf8, Spa
if (src == srcEnd)
goto DoneExit;
}

end = srcMax - 96;
if (BitConverter.IsLittleEndian && AdvSimd.Arm64.IsSupported && (end >= src))
{
AdvSimdDecode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);

if (src == srcEnd)
goto DoneExit;
}
}

// Last bytes could have padding characters, so process them separately and treat them as valid only if isFinalBlock is true
Expand Down Expand Up @@ -644,6 +655,133 @@ private static unsafe void Ssse3Decode(ref byte* srcBytes, ref byte* destBytes,
destBytes = dest;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector128<byte> AdvSimdTbx8Byte(Vector128<byte> defaults, Vector128<byte> table0, Vector128<byte> table1, Vector128<byte> table2, Vector128<byte> table3, Vector128<byte> table4, Vector128<byte> table5, Vector128<byte> table6, Vector128<byte> table7, Vector128<byte> indicies, Vector128<byte> offset)
{
// Implement an 8 way table lookup.
// This could be reduced by using two NEON TBX4 instructions.

Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian);

Vector128<byte> dest = defaults;
Vector128<byte> indicies_sub = indicies;

dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table0, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table1, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table2, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table3, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table4, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table5, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table6, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table7, indicies_sub);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this using perf, a lot of time is spent in this function. One reason for that is due to the chain of dependencies.

Splitting the indicies_sub into separate variables didn't make any noticeable difference:
var indicies1 = AdvSimd.Subtract(indicies, Vector128.Create((byte)(16U1)));
var indicies2 = AdvSimd.Subtract(indicies, Vector128.Create((byte)(16U
2)));
var indicies3 = AdvSimd.Subtract(indicies, Vector128.Create((byte)(16U*3)));
etc

The TBXs could be split out:
increment everything in the lookup table so that it starts from 1. Do the lookups using TBLs, meaning failures are 0. Combine all the results with ORs. Then subtract 1 from the result.
That would add more complexity, and I very much doubt it's going to give much benefit overall.


return dest;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector128<byte> AdvSimdTbx3Byte(Vector128<byte> defaults, Vector128<byte> table0, Vector128<byte> table1, Vector128<byte> table2, Vector128<byte> indicies, Vector128<byte> offset)
{
// Implement a 3 way table lookup.

Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian);

Vector128<byte> dest = defaults;
Vector128<byte> indicies_sub = indicies;

dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table0, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table1, indicies_sub);
indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table2, indicies_sub);

return dest;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void AdvSimdDecode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious - does AggressiveInlining here show benefits in the benchmarks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remove this one it doesn't make a difference.
If I remove the two around AdvSimdTbx3Byte and AdvSimdTbx8Byte it gets 2x worse.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those definitely make sense while this one is a bit questionable

{
Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian);

// Complete lookup table - similar to that used in the SS3 decode.
Vector128<byte> dec_lut0 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255);
Vector128<byte> dec_lut1 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255);
Vector128<byte> dec_lut2 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 255, 255, 63);
Vector128<byte> dec_lut3 = Vector128.Create( 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255);
Vector128<byte> dec_lut4 = Vector128.Create(255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
Vector128<byte> dec_lut5 = Vector128.Create( 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 255, 255, 255, 255, 255);
Vector128<byte> dec_lut6 = Vector128.Create(255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40);
Vector128<byte> dec_lut7 = Vector128.Create( 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255);

// Interleave pattern for the ST3.
Vector128<byte> st3_interleave_index0 = Vector128.Create((byte) 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35, 4, 20, 36, 5);
Vector128<byte> st3_interleave_index1 = Vector128.Create((byte)21, 37, 6, 22, 38, 7, 23, 39, 8, 24, 40, 9, 25, 41, 10, 26);
Vector128<byte> st3_interleave_index2 = Vector128.Create((byte)42, 11, 27, 43, 12, 28, 44, 13, 29, 45, 14, 30, 46, 15, 31, 47);

// Some constants.
Vector128<byte> vzero = Vector128.Create((byte)0);
Vector128<byte> v255 = Vector128.Create((byte)255U);
Vector128<byte> v16 = Vector128.Create((byte)16U);

byte* src = srcBytes;
byte* dest = destBytes;

do
{
// Load 64 bytes of data and deinterleave the result.
// This is equivalent to a NEON LD4 instruction.
Vector128<byte> str0 = Vector128.LoadUnsafe(ref *src);
Vector128<byte> str1 = Vector128.LoadUnsafe(ref *src, 16);
Vector128<byte> str2 = Vector128.LoadUnsafe(ref *src, 32);
Vector128<byte> str3 = Vector128.LoadUnsafe(ref *src, 48);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect you can use AdvSimd.Arm64.LoadPairVector128 (and even nontemporal if it makes sense) here - jit is not smart enough yet to do it by itself

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing this worked, but didn't give any improvement

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect you can use AdvSimd.Arm64.LoadPairVector128 (and even nontemporal if it makes sense) here - jit is not smart enough yet to do it by itself

Vector128<short> tmp0 = AdvSimd.Arm64.UnzipEven(str0.AsInt16(), str1.AsInt16());
Vector128<short> tmp1 = AdvSimd.Arm64.UnzipOdd(str0.AsInt16(), str1.AsInt16());
Vector128<short> tmp2 = AdvSimd.Arm64.UnzipEven(str2.AsInt16(), str3.AsInt16());
Vector128<short> tmp3 = AdvSimd.Arm64.UnzipOdd(str2.AsInt16(), str3.AsInt16());
str0 = AdvSimd.Arm64.UnzipEven(tmp0.AsByte(), tmp2.AsByte());
str1 = AdvSimd.Arm64.UnzipOdd(tmp0.AsByte(), tmp2.AsByte());
str2 = AdvSimd.Arm64.UnzipEven(tmp1.AsByte(), tmp3.AsByte());
str3 = AdvSimd.Arm64.UnzipOdd(tmp1.AsByte(), tmp3.AsByte());

// Table lookup on each 16 bytes.
str0 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str0, v16);
str1 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str1, v16);
str2 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str2, v16);
str3 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str3, v16);

// Check for invalid input, any value larger than 63.
Vector128<byte> classified0 = AdvSimd.Arm64.MaxPairwise(str0, str1);
Vector128<byte> classified1 = AdvSimd.Arm64.MaxPairwise(str2, str3);
Vector128<byte> maxChars = AdvSimd.Arm64.MaxPairwise(classified0, classified1);
if ((maxChars.AsUInt64().ToScalar() & 0xc0c0c0c0c0c0c0c0) != 0)
break;

// Compress each four bytes into three.
Vector128<byte> dec0 = Vector128.BitwiseOr(Vector128.ShiftLeft(str0, 2), Vector128.ShiftRightLogical(str1, 4));
Vector128<byte> dec1 = Vector128.BitwiseOr(Vector128.ShiftLeft(str1, 4), Vector128.ShiftRightLogical(str2, 2));
Vector128<byte> dec2 = Vector128.BitwiseOr(Vector128.ShiftLeft(str2, 6), str3);
Copy link
Member

@EgorBo EgorBo Jun 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jit doesn't do instruction selection so if you want your shifts to be side-by-side for better pipelining you need to extract them to temp locals, e.g.:

var sl0 = Vector128.ShiftLeft(str0, 2);
var sl1 = Vector128.ShiftLeft(str1, 4);
var sl2 = Vector128.ShiftLeft(str2, 6);

var sr1 = Vector128.ShiftRightLogical(str1, 4);
var sr2 = Vector128.ShiftRightLogical(str2, 2);

Vector128<byte> dec0 = Vector128.BitwiseOr(sl0, sr1);
Vector128<byte> dec1 = Vector128.BitwiseOr(sl1, sr2);
Vector128<byte> dec2 = Vector128.BitwiseOr(sl2, str3);

not sure it matters much in terms of perf

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copying that verbatim didn't make any difference.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another way to do this might be to treat the vector register as a single value and do something fancy with a masks and a single value shift. Not quite sure how that would look. Might get a some performance, but it'd be messy.


// Interleave the decoded result and store out.
// This is equivalent to a NEON ST3 instruction.
AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index0, v16).Store(dest);
AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index1, v16).Store(dest + 16);
AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index2, v16).Store(dest + 32);

src += 64;
dest += 48;
}
while (src <= srcEnd);

srcBytes = src;
destBytes = dest;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap)
{
Expand Down