Skip to content

Commit

Permalink
Use FMA in TensorPrimitives (#92205)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Sep 21, 2023
1 parent 9c8ff4c commit 85a68b0
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;

namespace System.Numerics.Tensors
{
Expand Down Expand Up @@ -80,9 +83,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
Vector512<float> xVec = Vector512.LoadUnsafe(ref xRef, (uint)i);
Vector512<float> yVec = Vector512.LoadUnsafe(ref yRef, (uint)i);

dotProductVector += xVec * yVec;
xSumOfSquaresVector += xVec * xVec;
ySumOfSquaresVector += yVec * yVec;
dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector);
xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector);
ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector);

i += Vector512<float>.Count;
}
Expand Down Expand Up @@ -111,9 +114,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
Vector256<float> xVec = Vector256.LoadUnsafe(ref xRef, (uint)i);
Vector256<float> yVec = Vector256.LoadUnsafe(ref yRef, (uint)i);

dotProductVector += xVec * yVec;
xSumOfSquaresVector += xVec * xVec;
ySumOfSquaresVector += yVec * yVec;
dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector);
xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector);
ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector);

i += Vector256<float>.Count;
}
Expand All @@ -140,9 +143,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
Vector128<float> xVec = Vector128.LoadUnsafe(ref xRef, (uint)i);
Vector128<float> yVec = Vector128.LoadUnsafe(ref yRef, (uint)i);

dotProductVector += xVec * yVec;
xSumOfSquaresVector += xVec * xVec;
ySumOfSquaresVector += yVec * yVec;
dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector);
xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector);
ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector);

i += Vector128<float>.Count;
}
Expand All @@ -157,9 +160,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
// Process any remaining elements past the last vector.
for (; (uint)i < (uint)x.Length; i++)
{
dotProduct += x[i] * y[i];
xSumOfSquares += x[i] * x[i];
ySumOfSquares += y[i] * y[i];
dotProduct = MathF.FusedMultiplyAdd(x[i], y[i], dotProduct);
xSumOfSquares = MathF.FusedMultiplyAdd(x[i], x[i], xSumOfSquares);
ySumOfSquares = MathF.FusedMultiplyAdd(y[i], y[i], ySumOfSquares);
}

// Sum(X * Y) / (|X| * |Y|)
Expand Down Expand Up @@ -1026,6 +1029,46 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan<TTernaryOperator>(
}
}

/// <summary>
/// Computes <c>(x * y) + addend</c> element-wise over 128-bit vectors, using a hardware
/// fused multiply-add instruction when one is available.
/// </summary>
/// <param name="x">The first multiplicand.</param>
/// <param name="y">The second multiplicand.</param>
/// <param name="addend">The vector added to the product of <paramref name="x"/> and <paramref name="y"/>.</param>
/// <returns>A vector holding <c>(x * y) + addend</c> for each element.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<float> FusedMultiplyAdd(Vector128<float> x, Vector128<float> y, Vector128<float> addend)
{
    if (Fma.IsSupported)
    {
        // x86 FMA: operands are passed as (multiplicand, multiplicand, addend).
        return Fma.MultiplyAdd(x, y, addend);
    }
    else if (AdvSimd.IsSupported)
    {
        // Arm AdvSimd: the addend is the first argument, followed by the two multiplicands.
        return AdvSimd.FusedMultiplyAdd(addend, x, y);
    }
    else
    {
        // No fused instruction was detected: do a separate multiply followed by an add.
        Vector128<float> product = x * y;
        return product + addend;
    }
}

/// <summary>
/// Computes <c>(x * y) + addend</c> element-wise over 256-bit vectors, using the x86 FMA
/// instruction when it is available and a separate multiply and add otherwise.
/// </summary>
/// <param name="x">The first multiplicand.</param>
/// <param name="y">The second multiplicand.</param>
/// <param name="addend">The vector added to the product of <paramref name="x"/> and <paramref name="y"/>.</param>
/// <returns>A vector holding <c>(x * y) + addend</c> for each element.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector256<float> FusedMultiplyAdd(Vector256<float> x, Vector256<float> y, Vector256<float> addend)
{
    // Only the x86 FMA path is checked for 256-bit vectors; everything else takes the fallback.
    return Fma.IsSupported
        ? Fma.MultiplyAdd(x, y, addend)
        : (x * y) + addend;
}

#if NET8_0_OR_GREATER
/// <summary>
/// Computes <c>(x * y) + addend</c> element-wise over 512-bit vectors, using the AVX-512F
/// fused multiply-add instruction when it is available and a separate multiply and add otherwise.
/// </summary>
/// <param name="x">The first multiplicand.</param>
/// <param name="y">The second multiplicand.</param>
/// <param name="addend">The vector added to the product of <paramref name="x"/> and <paramref name="y"/>.</param>
/// <returns>A vector holding <c>(x * y) + addend</c> for each element.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector512<float> FusedMultiplyAdd(Vector512<float> x, Vector512<float> y, Vector512<float> addend)
{
    // Only the AVX-512F path is checked for 512-bit vectors; everything else takes the fallback.
    return Avx512F.IsSupported
        ? Avx512F.FusedMultiplyAdd(x, y, addend)
        : (x * y) + addend;
}
#endif

private readonly struct AddOperator : IBinaryOperator
{
public static float Invoke(float x, float y) => x + y;
Expand Down Expand Up @@ -1176,11 +1219,11 @@ public static float Invoke(Vector512<float> x)

private readonly struct MultiplyAddOperator : ITernaryOperator
{
public static float Invoke(float x, float y, float z) => (x * y) + z;
public static Vector128<float> Invoke(Vector128<float> x, Vector128<float> y, Vector128<float> z) => (x * y) + z;
public static Vector256<float> Invoke(Vector256<float> x, Vector256<float> y, Vector256<float> z) => (x * y) + z;
public static float Invoke(float x, float y, float z) => MathF.FusedMultiplyAdd(x, y, z);
public static Vector128<float> Invoke(Vector128<float> x, Vector128<float> y, Vector128<float> z) => FusedMultiplyAdd(x, y, z);
public static Vector256<float> Invoke(Vector256<float> x, Vector256<float> y, Vector256<float> z) => FusedMultiplyAdd(x, y, z);
#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y, Vector512<float> z) => (x * y) + z;
public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y, Vector512<float> z) => FusedMultiplyAdd(x, y, z);
#endif
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static void AddTwoTensors(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] + y[i], destination[i]);
Assert.Equal(x[i] + y[i], destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -94,7 +94,7 @@ public static void AddTensorAndScalar(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] + y, destination[i]);
Assert.Equal(x[i] + y, destination[i], Tolerance);
}
}

Expand All @@ -121,7 +121,7 @@ public static void SubtractTwoTensors(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] - y[i], destination[i]);
Assert.Equal(x[i] - y[i], destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -159,7 +159,7 @@ public static void SubtractTensorAndScalar(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] - y, destination[i]);
Assert.Equal(x[i] - y, destination[i], Tolerance);
}
}

Expand All @@ -186,7 +186,7 @@ public static void MultiplyTwoTensors(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] * y[i], destination[i]);
Assert.Equal(x[i] * y[i], destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -224,7 +224,7 @@ public static void MultiplyTensorAndScalar(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] * y, destination[i]);
Assert.Equal(x[i] * y, destination[i], Tolerance);
}
}

Expand All @@ -251,7 +251,7 @@ public static void DivideTwoTensors(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] / y[i], destination[i]);
Assert.Equal(x[i] / y[i], destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -289,7 +289,7 @@ public static void DivideTensorAndScalar(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] / y, destination[i]);
Assert.Equal(x[i] / y, destination[i], Tolerance);
}
}

Expand All @@ -315,7 +315,7 @@ public static void NegateTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(-x[i], destination[i]);
Assert.Equal(-x[i], destination[i], Tolerance);
}
}

Expand All @@ -342,7 +342,7 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i]);
Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -395,7 +395,7 @@ public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal((x[i] + y[i]) * multiplier, destination[i]);
Assert.Equal((x[i] + y[i]) * multiplier, destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -436,7 +436,7 @@ public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal((x[i] + y) * multiplier[i], destination[i]);
Assert.Equal((x[i] + y) * multiplier[i], destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -477,7 +477,7 @@ public static void MultiplyTwoTensorsAndAddWithThirdTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal((x[i] * y[i]) + addend[i], destination[i]);
Assert.Equal((x[i] * y[i]) + addend[i], destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -530,7 +530,7 @@ public static void MultiplyTwoTensorsAndAddWithScalar(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal((x[i] * y[i]) + addend, destination[i]);
Assert.Equal((x[i] * y[i]) + addend, destination[i], Tolerance);
}
}

Expand Down Expand Up @@ -559,7 +559,7 @@ public static void MultiplyTensorAndScalarAndAddWithTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal((x[i] * y) + addend[i], destination[i]);
Assert.Equal((x[i] * y) + addend[i], destination[i], Tolerance);
}
}

Expand All @@ -586,7 +586,7 @@ public static void ExpTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(MathF.Exp(x[i]), destination[i]);
Assert.Equal(MathF.Exp(x[i]), destination[i], Tolerance);
}
}

Expand All @@ -611,7 +611,7 @@ public static void LogTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(MathF.Log(x[i]), destination[i]);
Assert.Equal(MathF.Log(x[i]), destination[i], Tolerance);
}
}

Expand All @@ -636,7 +636,7 @@ public static void CoshTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(MathF.Cosh(x[i]), destination[i]);
Assert.Equal(MathF.Cosh(x[i]), destination[i], Tolerance);
}
}

Expand All @@ -661,7 +661,7 @@ public static void SinhTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(MathF.Sinh(x[i]), destination[i]);
Assert.Equal(MathF.Sinh(x[i]), destination[i], Tolerance);
}
}

Expand All @@ -686,7 +686,7 @@ public static void TanhTensor(int tensorLength)

for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(MathF.Tanh(x[i]), destination[i]);
Assert.Equal(MathF.Tanh(x[i]), destination[i], Tolerance);
}
}

Expand Down

0 comments on commit 85a68b0

Please sign in to comment.