From fbd3b980d9842b72b00885c1cc17347424802944 Mon Sep 17 00:00:00 2001 From: Levi Broderick Date: Sat, 17 Apr 2021 01:29:50 -0700 Subject: [PATCH] Optimize Span.Fill implementation (#51365) --- .../System.Memory/tests/Span/Fill.cs | 110 +++++++++++ .../System.Private.CoreLib/src/System/Span.cs | 56 ++---- .../src/System/SpanHelpers.T.cs | 178 +++++++++++++++++- 3 files changed, 302 insertions(+), 42 deletions(-) diff --git a/src/libraries/System.Memory/tests/Span/Fill.cs b/src/libraries/System.Memory/tests/Span/Fill.cs index 558291da9b4e4..c6741219aad93 100644 --- a/src/libraries/System.Memory/tests/Span/Fill.cs +++ b/src/libraries/System.Memory/tests/Span/Fill.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Linq; using System.Runtime.InteropServices; using Xunit; using static System.TestHelpers; @@ -147,5 +148,114 @@ public static unsafe void FillNativeBytes() Marshal.FreeHGlobal(new IntPtr(ptr)); } } + + [Fact] + public static void FillWithRecognizedType() + { + RunTest(0x20); + RunTest(0x20); + RunTest(true); + RunTest(0x1234); + RunTest(0x1234); + RunTest('x'); + RunTest(0x12345678); + RunTest(0x12345678); + RunTest(0x0123456789abcdef); + RunTest(0x0123456789abcdef); + RunTest(unchecked((nint)0x0123456789abcdef)); + RunTest(unchecked((nuint)0x0123456789abcdef)); + RunTest((Half)1.0); + RunTest(1.0f); + RunTest(1.0); + RunTest(StringComparison.CurrentCultureIgnoreCase); // should be treated as underlying primitive + RunTest("Hello world!"); // ref type, no SIMD + RunTest(1.0m); // 128-bit struct + RunTest(new Guid("29e07627-2481-4f43-8fbf-09cf21180239")); // 128-bit struct + RunTest(new(0x11111111, 0x22222222, 0x33333333)); // 96-bit struct, no SIMD + RunTest(new(0x1111111111111111, 0x2222222222222222, 0x3333333333333333, 0x4444444444444444)); + RunTest(new( + 0x1111111111111111, 0x2222222222222222, 0x3333333333333333, 0x4444444444444444, + 
0x5555555555555555, 0x6666666666666666, 0x7777777777777777, 0x8888888888888888)); // 512-bit struct, no SIMD + RunTest<MyRefContainingStruct>(new("Hello world!")); // struct contains refs, no SIMD + + static void RunTest<T>(T value) + { + T[] arr = new T[128]; + + // Run tests for lengths := 0 to 64, ensuring we don't overrun our buffer + + for (int i = 0; i <= 64; i++) + { + arr.AsSpan(0, i).Fill(value); + Assert.Equal(Enumerable.Repeat(value, i), arr.Take(i)); // first i entries should've been populated with 'value' + Assert.Equal(Enumerable.Repeat(default(T), arr.Length - i), arr.Skip(i)); // remaining entries should contain default(T) + Array.Clear(arr, 0, arr.Length); + } + } + } + + private readonly struct My96BitStruct + { + public My96BitStruct(int data0, int data1, int data2) + { + Data0 = data0; + Data1 = data1; + Data2 = data2; + } + + public readonly int Data0; + public readonly int Data1; + public readonly int Data2; + } + + private readonly struct My256BitStruct + { + public My256BitStruct(ulong data0, ulong data1, ulong data2, ulong data3) + { + Data0 = data0; + Data1 = data1; + Data2 = data2; + Data3 = data3; + } + + public readonly ulong Data0; + public readonly ulong Data1; + public readonly ulong Data2; + public readonly ulong Data3; + } + + private readonly struct My512BitStruct + { + public My512BitStruct(ulong data0, ulong data1, ulong data2, ulong data3, ulong data4, ulong data5, ulong data6, ulong data7) + { + Data0 = data0; + Data1 = data1; + Data2 = data2; + Data3 = data3; + Data4 = data4; + Data5 = data5; + Data6 = data6; + Data7 = data7; + } + + public readonly ulong Data0; + public readonly ulong Data1; + public readonly ulong Data2; + public readonly ulong Data3; + public readonly ulong Data4; + public readonly ulong Data5; + public readonly ulong Data6; + public readonly ulong Data7; + } + + private readonly struct MyRefContainingStruct + { + public MyRefContainingStruct(object data) + { + Data = data; + } + + public readonly object Data; + } } } diff --git
a/src/libraries/System.Private.CoreLib/src/System/Span.cs b/src/libraries/System.Private.CoreLib/src/System/Span.cs index 60de1c0328665..753e9f279f374 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Span.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Span.cs @@ -5,7 +5,6 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Versioning; -using System.Text; using EditorBrowsableAttribute = System.ComponentModel.EditorBrowsableAttribute; using EditorBrowsableState = System.ComponentModel.EditorBrowsableState; using Internal.Runtime.CompilerServices; @@ -280,53 +279,28 @@ public unsafe void Clear() /// <summary> /// Fills the contents of this span with the given value. /// </summary> + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void Fill(T value) { if (Unsafe.SizeOf<T>() == 1) { - uint length = (uint)_length; - if (length == 0) - return; - - T tmp = value; // Avoid taking address of the "value" argument. It would regress performance of the loop below. - Unsafe.InitBlockUnaligned(ref Unsafe.As<T, byte>(ref _pointer.Value), Unsafe.As<T, byte>(ref tmp), length); +#if MONO + // Mono runtime's implementation of initblk performs a null check on the address. + // We'll perform a length check here to avoid passing a null address in the empty span case. + if (_length != 0) +#endif + { + // Special-case single-byte types like byte / sbyte / bool. + // The runtime eventually calls memset, which can efficiently support large buffers. + // We don't need to check IsReferenceOrContainsReferences because no references + // can ever be stored in types this small. + Unsafe.InitBlockUnaligned(ref Unsafe.As<T, byte>(ref _pointer.Value), Unsafe.As<T, byte>(ref value), (uint)_length); + } } else { - // Do all math as nuint to avoid unnecessary 64->32->64 bit integer truncations - nuint length = (uint)_length; - if (length == 0) - return; - - ref T r = ref _pointer.Value; - - // TODO: Create block fill for value types of power of two sizes e.g.
2,4,8,16 - - nuint elementSize = (uint)Unsafe.SizeOf<T>(); - nuint i = 0; - for (; i < (length & ~(nuint)7); i += 8) - { - Unsafe.AddByteOffset(ref r, (i + 0) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 1) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 2) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 3) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 4) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 5) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 6) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 7) * elementSize) = value; - } - if (i < (length & ~(nuint)3)) - { - Unsafe.AddByteOffset(ref r, (i + 0) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 1) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 2) * elementSize) = value; - Unsafe.AddByteOffset(ref r, (i + 3) * elementSize) = value; - i += 4; - } - for (; i < length; i++) - { - Unsafe.AddByteOffset(ref r, i * elementSize) = value; - } + // Call our optimized workhorse method for all other types. + SpanHelpers.Fill(ref _pointer.Value, (uint)_length, value); } } diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs index 2f2763f24a1a1..339aa157f8b7e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs @@ -2,13 +2,189 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; - +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; using Internal.Runtime.CompilerServices; namespace System { internal static partial class SpanHelpers // .T { + public static void Fill<T>(ref T refData, nuint numElements, T value) + { + // Early checks to see if it's even possible to vectorize - JIT will turn these checks into consts.
+ // - T cannot contain references (GC can't track references in vectors) + // - Vectorization must be hardware-accelerated + // - T's size must not exceed the vector's size + // - T's size must be a whole power of 2 + + if (RuntimeHelpers.IsReferenceOrContainsReferences<T>()) { goto CannotVectorize; } + if (!Vector.IsHardwareAccelerated) { goto CannotVectorize; } + if (Unsafe.SizeOf<T>() > Vector<byte>.Count) { goto CannotVectorize; } + if (!BitOperations.IsPow2(Unsafe.SizeOf<T>())) { goto CannotVectorize; } + + if (numElements >= (uint)(Vector<byte>.Count / Unsafe.SizeOf<T>())) + { + // We have enough data for at least one vectorized write. + + T tmp = value; // Avoid taking address of the "value" argument. It would regress performance of the loops below. + Vector<byte> vector; + + if (Unsafe.SizeOf<T>() == 1) + { + vector = new Vector<byte>(Unsafe.As<T, byte>(ref tmp)); + } + else if (Unsafe.SizeOf<T>() == 2) + { + vector = (Vector<byte>)(new Vector<ushort>(Unsafe.As<T, ushort>(ref tmp))); + } + else if (Unsafe.SizeOf<T>() == 4) + { + // special-case float since it's already passed in a SIMD reg + vector = (typeof(T) == typeof(float)) + ? (Vector<byte>)(new Vector<float>((float)(object)tmp!)) + : (Vector<byte>)(new Vector<uint>(Unsafe.As<T, uint>(ref tmp))); + } + else if (Unsafe.SizeOf<T>() == 8) + { + // special-case double since it's already passed in a SIMD reg + vector = (typeof(T) == typeof(double)) + ?
(Vector<byte>)(new Vector<double>((double)(object)tmp!)) + : (Vector<byte>)(new Vector<ulong>(Unsafe.As<T, ulong>(ref tmp))); + } + else if (Unsafe.SizeOf<T>() == 16) + { + Vector128<byte> vec128 = Unsafe.As<T, Vector128<byte>>(ref tmp); + if (Vector<byte>.Count == 16) + { + vector = vec128.AsVector(); + } + else if (Vector<byte>.Count == 32) + { + vector = Vector256.Create(vec128, vec128).AsVector(); + } + else + { + Debug.Fail("Vector<T> isn't 128 or 256 bits in size?"); + goto CannotVectorize; + } + } + else if (Unsafe.SizeOf<T>() == 32) + { + if (Vector<byte>.Count == 32) + { + vector = Unsafe.As<T, Vector256<byte>>(ref tmp).AsVector(); + } + else + { + Debug.Fail("Vector<T> isn't 256 bits in size?"); + goto CannotVectorize; + } + } + else + { + Debug.Fail("Vector<T> is greater than 256 bits in size?"); + goto CannotVectorize; + } + + ref byte refDataAsBytes = ref Unsafe.As<T, byte>(ref refData); + nuint totalByteLength = numElements * (nuint)Unsafe.SizeOf<T>(); // get this calculation ready ahead of time + nuint stopLoopAtOffset = totalByteLength & (nuint)(nint)(2 * (int)-Vector<byte>.Count); // intentional sign extension carries the negative bit + nuint offset = 0; + + // Loop, writing 2 vectors at a time. + // Compare 'numElements' rather than 'stopLoopAtOffset' because we don't want a dependency + // on the very recently calculated 'stopLoopAtOffset' value. + + if (numElements >= (uint)(2 * Vector<byte>.Count / Unsafe.SizeOf<T>())) + { + do + { + Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), vector); + Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset + (nuint)Vector<byte>.Count), vector); + offset += (uint)(2 * Vector<byte>.Count); + } while (offset < stopLoopAtOffset); + } + + // At this point, if any data remains to be written, it's strictly less than + // 2 * sizeof(Vector<byte>) bytes. The loop above had us write an even number of vectors. + // If the total byte length instead involves us writing an odd number of vectors, write + // one additional vector now. The bit check below tells us if we're in an "odd vector + // count" situation.
+ + if ((totalByteLength & (nuint)Vector<byte>.Count) != 0) + { + Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), vector); + } + + // It's possible that some small buffer remains to be populated - something that won't + // fit an entire vector's worth of data. Instead of falling back to a loop, we'll write + // a vector at the very end of the buffer. This may involve overwriting previously + // populated data, which is fine since we're splatting the same value for all entries. + // There's no need to perform a length check here because we already performed this + // check before entering the vectorized code path. + + Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, totalByteLength - (nuint)Vector<byte>.Count), vector); + + // And we're done! + + return; + } + + CannotVectorize: + + // If we reached this point, we cannot vectorize this T, or there are too few + // elements for us to vectorize. Fall back to an unrolled loop. + + nuint i = 0; + + // Write 8 elements at a time + + if (numElements >= 8) + { + nuint stopLoopAtOffset = numElements & ~(nuint)7; + do + { + Unsafe.Add(ref refData, (nint)i + 0) = value; + Unsafe.Add(ref refData, (nint)i + 1) = value; + Unsafe.Add(ref refData, (nint)i + 2) = value; + Unsafe.Add(ref refData, (nint)i + 3) = value; + Unsafe.Add(ref refData, (nint)i + 4) = value; + Unsafe.Add(ref refData, (nint)i + 5) = value; + Unsafe.Add(ref refData, (nint)i + 6) = value; + Unsafe.Add(ref refData, (nint)i + 7) = value; + } while ((i += 8) < stopLoopAtOffset); + } + + // Write next 4 elements if needed + + if ((numElements & 4) != 0) + { + Unsafe.Add(ref refData, (nint)i + 0) = value; + Unsafe.Add(ref refData, (nint)i + 1) = value; + Unsafe.Add(ref refData, (nint)i + 2) = value; + Unsafe.Add(ref refData, (nint)i + 3) = value; + i += 4; + } + + // Write next 2 elements if needed + + if ((numElements & 2) != 0) + { + Unsafe.Add(ref refData, (nint)i + 0) = value; + Unsafe.Add(ref refData, (nint)i + 1) = value; + i
+= 2; + } + + // Write final element if needed + + if ((numElements & 1) != 0) + { + Unsafe.Add(ref refData, (nint)i) = value; + } + } + public static int IndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T> { Debug.Assert(searchSpaceLength >= 0);