Skip to content

Commit

Permalink
Added support for Generic Math using net70 SDK.
Browse files Browse the repository at this point in the history
  • Loading branch information
MoFtZ authored and m4rs-mt committed Dec 5, 2022
1 parent f4b3445 commit 7bf5578
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 4 deletions.
1 change: 1 addition & 0 deletions Src/ILGPU.Tests/Configurations.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ DisassemblerTests
EnumValues
FixedBuffers
GetKernelTests
GenericMath
Indices
KernelEntryPoints
MemoryBufferOperations
Expand Down
99 changes: 99 additions & 0 deletions Src/ILGPU.Tests/GenericMath.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// ---------------------------------------------------------------------------------------
// ILGPU
// Copyright (c) 2022 ILGPU Project
// www.ilgpu.net
//
// File: GenericMath.cs
//
// This file is part of ILGPU and is distributed under the University of Illinois Open
// Source License. See LICENSE.txt for details.
// ---------------------------------------------------------------------------------------

using ILGPU.Runtime;
using System.Linq;
using System.Numerics;
using Xunit;
using Xunit.Abstractions;

namespace ILGPU.Tests
{
public abstract class GenericMath : TestBase
{
protected GenericMath(ITestOutputHelper output, TestContext testContext)
: base(output, testContext)
{ }

#if NET7_0_OR_GREATER

private const int Length = 1024;

public static T GeZeroIfBigger<T>(T value, T max) where T : INumber<T>
{
if (value > max)
return T.Zero;
return value;
}

internal static void GenericMathKernel<T>(
Index1D index,
ArrayView1D<T, Stride1D.Dense> input,
ArrayView1D<T, Stride1D.Dense> output,
T maxValue)
where T : unmanaged, INumber<T>
{
output[index] = GeZeroIfBigger(input[index], maxValue);
}

private void TestGenericMathKernel<T>(T[] inputValues, T[] expected, T maxValue)
where T : unmanaged, INumber<T>
{
using var input = Accelerator.Allocate1D<T>(inputValues);
using var output = Accelerator.Allocate1D<T>(Length);

using var start = Accelerator.DefaultStream.AddProfilingMarker();
Accelerator.LaunchAutoGrouped<
Index1D,
ArrayView1D<T, Stride1D.Dense>,
ArrayView1D<T, Stride1D.Dense>,
T>(
GenericMathKernel,
Accelerator.DefaultStream,
(int)input.Length,
input.View,
output.View,
maxValue);

Verify(output.View, expected);
}

[Fact]
public void GenericMathIntTest()
{
const int MaxValue = 50;
var input = Enumerable.Range(0, Length).ToArray();

var expected = input
.Select(x => GeZeroIfBigger(x, MaxValue))
.ToArray();

TestGenericMathKernel(input, expected, MaxValue);
}

[Fact]
public void GenericMathDoubleTest()
{
const double MaxValue = 75.0;
var input = Enumerable.Range(0, Length)
.Select(x => (double)x)
.ToArray();

var expected = input
.Select(x => GeZeroIfBigger(x, MaxValue))
.ToArray();

TestGenericMathKernel(input, expected, MaxValue);
}

#endif
}
}
3 changes: 2 additions & 1 deletion Src/ILGPU/Frontend/CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ private void SetupVariables()
}

// Initialize locals
var localVariables = Method.GetMethodBody().LocalVariables;
var methodBody = Disassembler.ExtractMethodBody(Method);
var localVariables = methodBody.LocalVariables;
for (int i = 0, e = localVariables.Count; i < e; ++i)
{
var variable = localVariables[i];
Expand Down
60 changes: 57 additions & 3 deletions Src/ILGPU/Frontend/Disassembler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// ---------------------------------------------------------------------------------------
// ILGPU
// Copyright (c) 2018-2021 ILGPU Project
// Copyright (c) 2018-2022 ILGPU Project
// www.ilgpu.net
//
// File: Disassembler.cs
Expand All @@ -15,6 +15,7 @@
using ILGPU.Util;
using System;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
Expand All @@ -27,7 +28,60 @@ namespace ILGPU.Frontend
/// <remarks>Members of this class are not thread safe.</remarks>
public sealed partial class Disassembler : ILocation
{
#region Constants
#region Static

/// <summary>
/// Indicates whether the .NET runtime has support for Static Abstract methods
/// in Interfaces.
/// </summary>
[SuppressMessage("Performance", "CA1802:Use literals where appropriate")]
private static readonly bool UsingStaticAbstractMethodsInInterfaces =
#if NET7_0_OR_GREATER
RuntimeFeature.IsSupported(RuntimeFeature.VirtualStaticsInInterfaces);
#else
false;
#endif

/// <summary>
/// Extracts the method body for the given method.
/// </summary>
/// <param name="method">The method.</param>
/// <returns>The method body.</returns>
public static MethodBody ExtractMethodBody(MethodBase method)
{
var methodBody = method.GetMethodBody();
if (methodBody == null &&
UsingStaticAbstractMethodsInInterfaces &&
method.DeclaringType.IsInterface &&
method.IsStatic &&
method.IsAbstract)
{
// Support for Static Abstract methods in Interfaces was introduced in
// C# 11, in particular, adding support for Generic Math. The interface
// itself does not contain the method implementation itself. Instead, we
// need to find the concrete type that implements the interface, and find
// the matching method.
var concreteType = method.DeclaringType.GetGenericArguments()[0];
var interfaceMap = concreteType.GetInterfaceMap(method.DeclaringType);

for (int i = 0; i < interfaceMap.InterfaceMethods.Length; i++)
{
if (interfaceMap.InterfaceMethods[i].Name.Equals(
method.Name,
StringComparison.OrdinalIgnoreCase))
{
methodBody = interfaceMap.TargetMethods[i].GetMethodBody();
break;
}
}
}

return methodBody;
}

#endregion

#region Constants

/// <summary>
/// Represents the native pointer type that is used during the
Expand Down Expand Up @@ -98,7 +152,7 @@ public Disassembler(
? MethodBase.GetGenericArguments()
: Array.Empty<Type>();
TypeGenericArguments = MethodBase.DeclaringType.GetGenericArguments();
MethodBody = MethodBase.GetMethodBody();
MethodBody = ExtractMethodBody(MethodBase);
if (MethodBody == null)
{
throw new NotSupportedException(string.Format(
Expand Down

0 comments on commit 7bf5578

Please sign in to comment.