Skip to content

Commit

Permalink
Optimize Span<T>.Fill implementation (#51365)
Browse files Browse the repository at this point in the history
  • Loading branch information
GrabYourPitchforks committed Apr 17, 2021
1 parent e9f3e6a commit fbd3b98
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 42 deletions.
110 changes: 110 additions & 0 deletions src/libraries/System.Memory/tests/Span/Fill.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
using static System.TestHelpers;
Expand Down Expand Up @@ -147,5 +148,114 @@ public static unsafe void FillNativeBytes()
Marshal.FreeHGlobal(new IntPtr(ptr));
}
}

[Fact]
public static void FillWithRecognizedType()
{
RunTest<sbyte>(0x20);
RunTest<byte>(0x20);
RunTest<bool>(true);
RunTest<short>(0x1234);
RunTest<ushort>(0x1234);
RunTest<char>('x');
RunTest<int>(0x12345678);
RunTest<uint>(0x12345678);
RunTest<long>(0x0123456789abcdef);
RunTest<ulong>(0x0123456789abcdef);
RunTest<nint>(unchecked((nint)0x0123456789abcdef));
RunTest<nuint>(unchecked((nuint)0x0123456789abcdef));
RunTest<Half>((Half)1.0);
RunTest<float>(1.0f);
RunTest<double>(1.0);
RunTest<StringComparison>(StringComparison.CurrentCultureIgnoreCase); // should be treated as underlying primitive
RunTest<string>("Hello world!"); // ref type, no SIMD
RunTest<decimal>(1.0m); // 128-bit struct
RunTest<Guid>(new Guid("29e07627-2481-4f43-8fbf-09cf21180239")); // 128-bit struct
RunTest<My96BitStruct>(new(0x11111111, 0x22222222, 0x33333333)); // 96-bit struct, no SIMD
RunTest<My256BitStruct>(new(0x1111111111111111, 0x2222222222222222, 0x3333333333333333, 0x4444444444444444));
RunTest<My512BitStruct>(new(
0x1111111111111111, 0x2222222222222222, 0x3333333333333333, 0x4444444444444444,
0x5555555555555555, 0x6666666666666666, 0x7777777777777777, 0x8888888888888888)); // 512-bit struct, no SIMD
RunTest<MyRefContainingStruct>(new("Hello world!")); // struct contains refs, no SIMD

static void RunTest<T>(T value)
{
T[] arr = new T[128];

// Run tests for lengths := 0 to 64, ensuring we don't overrun our buffer

for (int i = 0; i <= 64; i++)
{
arr.AsSpan(0, i).Fill(value);
Assert.Equal(Enumerable.Repeat(value, i), arr.Take(i)); // first i entries should've been populated with 'value'
Assert.Equal(Enumerable.Repeat(default(T), arr.Length - i), arr.Skip(i)); // remaining entries should contain default(T)
Array.Clear(arr, 0, arr.Length);
}
}
}

private readonly struct My96BitStruct
{
public My96BitStruct(int data0, int data1, int data2)
{
Data0 = data0;
Data1 = data1;
Data2 = data2;
}

public readonly int Data0;
public readonly int Data1;
public readonly int Data2;
}

private readonly struct My256BitStruct
{
public My256BitStruct(ulong data0, ulong data1, ulong data2, ulong data3)
{
Data0 = data0;
Data1 = data1;
Data2 = data2;
Data3 = data3;
}

public readonly ulong Data0;
public readonly ulong Data1;
public readonly ulong Data2;
public readonly ulong Data3;
}

private readonly struct My512BitStruct
{
public My512BitStruct(ulong data0, ulong data1, ulong data2, ulong data3, ulong data4, ulong data5, ulong data6, ulong data7)
{
Data0 = data0;
Data1 = data1;
Data2 = data2;
Data3 = data3;
Data4 = data4;
Data5 = data5;
Data6 = data6;
Data7 = data7;
}

public readonly ulong Data0;
public readonly ulong Data1;
public readonly ulong Data2;
public readonly ulong Data3;
public readonly ulong Data4;
public readonly ulong Data5;
public readonly ulong Data6;
public readonly ulong Data7;
}

private readonly struct MyRefContainingStruct
{
public MyRefContainingStruct(object data)
{
Data = data;
}

public readonly object Data;
}
}
}
56 changes: 15 additions & 41 deletions src/libraries/System.Private.CoreLib/src/System/Span.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Text;
using EditorBrowsableAttribute = System.ComponentModel.EditorBrowsableAttribute;
using EditorBrowsableState = System.ComponentModel.EditorBrowsableState;
using Internal.Runtime.CompilerServices;
Expand Down Expand Up @@ -280,53 +279,28 @@ public unsafe void Clear()
/// <summary>
/// Fills the contents of this span with the given value.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Fill(T value)
{
if (Unsafe.SizeOf<T>() == 1)
{
uint length = (uint)_length;
if (length == 0)
return;

T tmp = value; // Avoid taking address of the "value" argument. It would regress performance of the loop below.
Unsafe.InitBlockUnaligned(ref Unsafe.As<T, byte>(ref _pointer.Value), Unsafe.As<T, byte>(ref tmp), length);
#if MONO
// Mono runtime's implementation of initblk performs a null check on the address.
// We'll perform a length check here to avoid passing a null address in the empty span case.
if (_length != 0)
#endif
{
// Special-case single-byte types like byte / sbyte / bool.
// The runtime eventually calls memset, which can efficiently support large buffers.
// We don't need to check IsReferenceOrContainsReferences because no references
// can ever be stored in types this small.
Unsafe.InitBlockUnaligned(ref Unsafe.As<T, byte>(ref _pointer.Value), Unsafe.As<T, byte>(ref value), (uint)_length);
}
}
else
{
// Do all math as nuint to avoid unnecessary 64->32->64 bit integer truncations
nuint length = (uint)_length;
if (length == 0)
return;

ref T r = ref _pointer.Value;

// TODO: Create block fill for value types of power of two sizes e.g. 2,4,8,16

nuint elementSize = (uint)Unsafe.SizeOf<T>();
nuint i = 0;
for (; i < (length & ~(nuint)7); i += 8)
{
Unsafe.AddByteOffset<T>(ref r, (i + 0) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 1) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 2) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 3) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 4) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 5) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 6) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 7) * elementSize) = value;
}
if (i < (length & ~(nuint)3))
{
Unsafe.AddByteOffset<T>(ref r, (i + 0) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 1) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 2) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 3) * elementSize) = value;
i += 4;
}
for (; i < length; i++)
{
Unsafe.AddByteOffset<T>(ref r, i * elementSize) = value;
}
// Call our optimized workhorse method for all other types.
SpanHelpers.Fill(ref _pointer.Value, (uint)_length, value);
}
}

Expand Down
178 changes: 177 additions & 1 deletion src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,189 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;

using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using Internal.Runtime.CompilerServices;

namespace System
{
internal static partial class SpanHelpers // .T
{
public static void Fill<T>(ref T refData, nuint numElements, T value)
{
// Early checks to see if it's even possible to vectorize - JIT will turn these checks into consts.
// - T cannot contain references (GC can't track references in vectors)
// - Vectorization must be hardware-accelerated
// - T's size must not exceed the vector's size
// - T's size must be a whole power of 2

if (RuntimeHelpers.IsReferenceOrContainsReferences<T>()) { goto CannotVectorize; }
if (!Vector.IsHardwareAccelerated) { goto CannotVectorize; }
if (Unsafe.SizeOf<T>() > Vector<byte>.Count) { goto CannotVectorize; }
if (!BitOperations.IsPow2(Unsafe.SizeOf<T>())) { goto CannotVectorize; }

if (numElements >= (uint)(Vector<byte>.Count / Unsafe.SizeOf<T>()))
{
// We have enough data for at least one vectorized write.

T tmp = value; // Avoid taking address of the "value" argument. It would regress performance of the loops below.
Vector<byte> vector;

if (Unsafe.SizeOf<T>() == 1)
{
vector = new Vector<byte>(Unsafe.As<T, byte>(ref tmp));
}
else if (Unsafe.SizeOf<T>() == 2)
{
vector = (Vector<byte>)(new Vector<ushort>(Unsafe.As<T, ushort>(ref tmp)));
}
else if (Unsafe.SizeOf<T>() == 4)
{
// special-case float since it's already passed in a SIMD reg
vector = (typeof(T) == typeof(float))
? (Vector<byte>)(new Vector<float>((float)(object)tmp!))
: (Vector<byte>)(new Vector<uint>(Unsafe.As<T, uint>(ref tmp)));
}
else if (Unsafe.SizeOf<T>() == 8)
{
// special-case double since it's already passed in a SIMD reg
vector = (typeof(T) == typeof(double))
? (Vector<byte>)(new Vector<double>((double)(object)tmp!))
: (Vector<byte>)(new Vector<ulong>(Unsafe.As<T, ulong>(ref tmp)));
}
else if (Unsafe.SizeOf<T>() == 16)
{
Vector128<byte> vec128 = Unsafe.As<T, Vector128<byte>>(ref tmp);
if (Vector<byte>.Count == 16)
{
vector = vec128.AsVector();
}
else if (Vector<byte>.Count == 32)
{
vector = Vector256.Create(vec128, vec128).AsVector();
}
else
{
Debug.Fail("Vector<T> isn't 128 or 256 bits in size?");
goto CannotVectorize;
}
}
else if (Unsafe.SizeOf<T>() == 32)
{
if (Vector<byte>.Count == 32)
{
vector = Unsafe.As<T, Vector256<byte>>(ref tmp).AsVector();
}
else
{
Debug.Fail("Vector<T> isn't 256 bits in size?");
goto CannotVectorize;
}
}
else
{
Debug.Fail("Vector<T> is greater than 256 bits in size?");
goto CannotVectorize;
}

ref byte refDataAsBytes = ref Unsafe.As<T, byte>(ref refData);
nuint totalByteLength = numElements * (nuint)Unsafe.SizeOf<T>(); // get this calculation ready ahead of time
nuint stopLoopAtOffset = totalByteLength & (nuint)(nint)(2 * (int)-Vector<byte>.Count); // intentional sign extension carries the negative bit
nuint offset = 0;

// Loop, writing 2 vectors at a time.
// Compare 'numElements' rather than 'stopLoopAtOffset' because we don't want a dependency
// on the very recently calculated 'stopLoopAtOffset' value.

if (numElements >= (uint)(2 * Vector<byte>.Count / Unsafe.SizeOf<T>()))
{
do
{
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), vector);
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset + (nuint)Vector<byte>.Count), vector);
offset += (uint)(2 * Vector<byte>.Count);
} while (offset < stopLoopAtOffset);
}

// At this point, if any data remains to be written, it's strictly less than
// 2 * sizeof(Vector) bytes. The loop above had us write an even number of vectors.
// If the total byte length instead involves us writing an odd number of vectors, write
// one additional vector now. The bit check below tells us if we're in an "odd vector
// count" situation.

if ((totalByteLength & (nuint)Vector<byte>.Count) != 0)
{
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), vector);
}

// It's possible that some small buffer remains to be populated - something that won't
// fit an entire vector's worth of data. Instead of falling back to a loop, we'll write
// a vector at the very end of the buffer. This may involve overwriting previously
// populated data, which is fine since we're splatting the same value for all entries.
// There's no need to perform a length check here because we already performed this
// check before entering the vectorized code path.

Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, totalByteLength - (nuint)Vector<byte>.Count), vector);

// And we're done!

return;
}

CannotVectorize:

// If we reached this point, we cannot vectorize this T, or there are too few
// elements for us to vectorize. Fall back to an unrolled loop.

nuint i = 0;

// Write 8 elements at a time

if (numElements >= 8)
{
nuint stopLoopAtOffset = numElements & ~(nuint)7;
do
{
Unsafe.Add(ref refData, (nint)i + 0) = value;
Unsafe.Add(ref refData, (nint)i + 1) = value;
Unsafe.Add(ref refData, (nint)i + 2) = value;
Unsafe.Add(ref refData, (nint)i + 3) = value;
Unsafe.Add(ref refData, (nint)i + 4) = value;
Unsafe.Add(ref refData, (nint)i + 5) = value;
Unsafe.Add(ref refData, (nint)i + 6) = value;
Unsafe.Add(ref refData, (nint)i + 7) = value;
} while ((i += 8) < stopLoopAtOffset);
}

// Write next 4 elements if needed

if ((numElements & 4) != 0)
{
Unsafe.Add(ref refData, (nint)i + 0) = value;
Unsafe.Add(ref refData, (nint)i + 1) = value;
Unsafe.Add(ref refData, (nint)i + 2) = value;
Unsafe.Add(ref refData, (nint)i + 3) = value;
i += 4;
}

// Write next 2 elements if needed

if ((numElements & 2) != 0)
{
Unsafe.Add(ref refData, (nint)i + 0) = value;
Unsafe.Add(ref refData, (nint)i + 1) = value;
i += 2;
}

// Write final element if needed

if ((numElements & 1) != 0)
{
Unsafe.Add(ref refData, (nint)i) = value;
}
}

public static int IndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>
{
Debug.Assert(searchSpaceLength >= 0);
Expand Down

0 comments on commit fbd3b98

Please sign in to comment.