From 85a68b0638241595395dd4756a29818c2a4c9d7e Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 21 Sep 2023 10:45:44 -0400 Subject: [PATCH] Use FMA in TensorPrimitives (#92205) --- .../Tensors/TensorPrimitives.netcore.cs | 75 +++++++++++++++---- .../tests/TensorPrimitivesTests.cs | 40 +++++----- 2 files changed, 79 insertions(+), 36 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index ae5af404ac1af..652771045b2d5 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -1,9 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.X86; namespace System.Numerics.Tensors { @@ -80,9 +83,9 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); i += Vector512.Count; } @@ -111,9 +114,9 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); i += Vector256.Count; } @@ -140,9 +143,9 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); i += Vector128.Count; } @@ -157,9 +160,9 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan( } } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 FusedMultiplyAdd(Vector128 x, Vector128 y, Vector128 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + if (AdvSimd.IsSupported) + { + return AdvSimd.FusedMultiplyAdd(addend, x, y); + } + + return (x * y) + addend; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 FusedMultiplyAdd(Vector256 x, Vector256 y, Vector256 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 FusedMultiplyAdd(Vector512 x, Vector512 y, Vector512 addend) + { + if (Avx512F.IsSupported) + { + return Avx512F.FusedMultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } +#endif + private readonly struct AddOperator : IBinaryOperator { public static float Invoke(float x, float y) => x + y; @@ -1176,11 +1219,11 @@ public static float Invoke(Vector512 x) private readonly struct MultiplyAddOperator : ITernaryOperator { - public static float Invoke(float x, float y, float z) => (x * y) + z; - public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; - public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; + public static float Invoke(float x, float y, float z) => MathF.FusedMultiplyAdd(x, y, z); + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => FusedMultiplyAdd(x, y, z); + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => FusedMultiplyAdd(x, y, z); #if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => FusedMultiplyAdd(x, y, z); #endif } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 181d152e6ae97..d566c9d3cd04a 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -56,7 +56,7 @@ public static void AddTwoTensors(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] + y[i], destination[i]); + Assert.Equal(x[i] + y[i], destination[i], Tolerance); } } @@ -94,7 +94,7 @@ public static void AddTensorAndScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] + y, destination[i]); + Assert.Equal(x[i] + y, destination[i], Tolerance); } } @@ -121,7 +121,7 @@ public static void SubtractTwoTensors(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] - y[i], destination[i]); + Assert.Equal(x[i] - y[i], destination[i], Tolerance); } } @@ -159,7 +159,7 @@ public static void SubtractTensorAndScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] - y, destination[i]); + Assert.Equal(x[i] - y, destination[i], Tolerance); } } @@ -186,7 +186,7 @@ public static void MultiplyTwoTensors(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] * y[i], destination[i]); + Assert.Equal(x[i] * y[i], destination[i], Tolerance); } } @@ -224,7 +224,7 @@ public static void MultiplyTensorAndScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] * y, destination[i]); + Assert.Equal(x[i] * y, destination[i], Tolerance); } } @@ -251,7 +251,7 @@ public static void DivideTwoTensors(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] / y[i], destination[i]); + Assert.Equal(x[i] / y[i], destination[i], Tolerance); } } @@ -289,7 +289,7 @@ public static void DivideTensorAndScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] / y, destination[i]); + Assert.Equal(x[i] / y, destination[i], Tolerance); } } @@ -315,7 +315,7 @@ public static void NegateTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(-x[i], destination[i]); + Assert.Equal(-x[i], destination[i], Tolerance); } } @@ -342,7 +342,7 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i]); + Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i], Tolerance); } } @@ -395,7 +395,7 @@ public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y[i]) * multiplier, destination[i]); + Assert.Equal((x[i] + y[i]) * multiplier, destination[i], Tolerance); } } @@ -436,7 +436,7 @@ public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y) * multiplier[i], destination[i]); + Assert.Equal((x[i] + y) * multiplier[i], destination[i], Tolerance); } } @@ -477,7 +477,7 @@ public static void MultiplyTwoTensorsAndAddWithThirdTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] * y[i]) + addend[i], destination[i]); + Assert.Equal((x[i] * y[i]) + addend[i], destination[i], Tolerance); } } @@ -530,7 +530,7 @@ public static void MultiplyTwoTensorsAndAddWithScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] * y[i]) + addend, destination[i]); + Assert.Equal((x[i] * y[i]) + addend, destination[i], Tolerance); } } @@ -559,7 +559,7 @@ public static void MultiplyTensorAndScalarAndAddWithTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] * y) + addend[i], destination[i]); + Assert.Equal((x[i] * y) + addend[i], destination[i], Tolerance); } } @@ -586,7 +586,7 @@ public static void ExpTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Exp(x[i]), destination[i]); + Assert.Equal(MathF.Exp(x[i]), destination[i], Tolerance); } } @@ -611,7 +611,7 @@ public static void LogTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Log(x[i]), destination[i]); + Assert.Equal(MathF.Log(x[i]), destination[i], Tolerance); } } @@ -636,7 +636,7 @@ public static void CoshTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Cosh(x[i]), destination[i]); + Assert.Equal(MathF.Cosh(x[i]), destination[i], Tolerance); } } @@ -661,7 +661,7 @@ public static void SinhTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Sinh(x[i]), destination[i]); + Assert.Equal(MathF.Sinh(x[i]), destination[i], Tolerance); } } @@ -686,7 +686,7 @@ public static void TanhTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Tanh(x[i]), destination[i]); + Assert.Equal(MathF.Tanh(x[i]), destination[i], Tolerance); } }