Skip to content

Commit

Permalink
Convert MemoryMarshal.GetArrayDataReference to a JIT intrinsic
Browse files Browse the repository at this point in the history
Converts MemoryMarshal.GetArrayDataReference to
an always expand JIT intrinsic and removes the VM intrinsics.

Introduces JIT tests validating the correct behaviour.

Fixes invalid codegen samples from:
dotnet#58312 (comment)
  • Loading branch information
MichalPetryka committed Jul 23, 2022
1 parent cbfc549 commit f1dc80f
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ public static unsafe partial class MemoryMarshal
/// </remarks>
[Intrinsic]
[NonVersionable]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref T GetArrayDataReference<T>(T[] array) =>
ref Unsafe.As<byte, T>(ref Unsafe.As<RawArrayData>(array).Data);
ref GetArrayDataReference(array);

/// <summary>
/// Returns a reference to the 0th element of <paramref name="array"/>. If the array is empty, returns a reference to where the 0th element
Expand Down
25 changes: 25 additions & 0 deletions src/coreclr/jit/importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3879,6 +3879,21 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,
break;
}

case NI_System_Runtime_InteropService_MemoryMarshal_GetArrayDataReference:
{
assert(sig->numArgs == 1);

GenTree* array = impPopStack().val;
CORINFO_CLASS_HANDLE elemHnd = sig->sigInst.classInst[0];
CorInfoType jitType = info.compCompHnd->asCorInfoType(elemHnd);
var_type elemType = JITtype2varType(jitType);

GenTree* index = gtNewIconNode(0, TYP_I_IMPL);
GenTreeIndexAddr* indexAddr = gtNewArrayIndexAddr(array, index, elemType, elemHnd);
indexAddr->gtFlags &= ~GTF_INX_RNGCHK;
break;
}

case NI_Internal_Runtime_MethodTable_Of:
case NI_System_Activator_AllocatorOf:
case NI_System_Activator_DefaultConstructorOf:
Expand Down Expand Up @@ -6084,6 +6099,16 @@ NamedIntrinsic Compiler::lookupNamedIntrinsic(CORINFO_METHOD_HANDLE method)
}
}
}
else if (strcmp(namespaceName, "System.Runtime.InteropServices") == 0)
{
if (strcmp(className, "MemoryMarshal") == 0)
{
if (strcmp(methodName, "GetArrayDataReference") == 0)
{
result = NI_System_Runtime_InteropService_MemoryMarshal_GetArrayDataReference;
}
}
}
else if (strncmp(namespaceName, "System.Runtime.Intrinsics", 25) == 0)
{
// We go down this path even when FEATURE_HW_INTRINSICS isn't enabled
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/namedintrinsiclist.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ enum NamedIntrinsic : unsigned short
NI_System_Runtime_CompilerServices_RuntimeHelpers_InitializeArray,
NI_System_Runtime_CompilerServices_RuntimeHelpers_IsKnownConstant,

NI_System_Runtime_InteropService_MemoryMarshal_GetArrayDataReference,

NI_System_String_Equals,
NI_System_String_get_Chars,
NI_System_String_get_Length,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ public static unsafe partial class MemoryMarshal
/// </remarks>
[Intrinsic]
[NonVersionable]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref T GetArrayDataReference<T>(T[] array) =>
ref Unsafe.As<byte, T>(ref Unsafe.As<RawArrayData>(array).Data);
ref GetArrayDataReference(array);

/// <summary>
/// Returns a reference to the 0th element of <paramref name="array"/>. If the array is empty, returns a reference to where the 0th element
Expand Down
1 change: 0 additions & 1 deletion src/coreclr/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,6 @@ DEFINE_METHOD(UNSAFE, UNBOX, Unbox, NoSig)
DEFINE_METHOD(UNSAFE, WRITE, Write, NoSig)

DEFINE_CLASS(MEMORY_MARSHAL, Interop, MemoryMarshal)
DEFINE_METHOD(MEMORY_MARSHAL, GET_ARRAY_DATA_REFERENCE_SZARRAY, GetArrayDataReference, GM_ArrT_RetRefT)
DEFINE_METHOD(MEMORY_MARSHAL, GET_ARRAY_DATA_REFERENCE_MDARRAY, GetArrayDataReference, SM_Array_RetRefByte)

DEFINE_CLASS(INTERLOCKED, Threading, Interlocked)
Expand Down
37 changes: 0 additions & 37 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6975,39 +6975,6 @@ bool getILIntrinsicImplementationForUnsafe(MethodDesc * ftn,
return false;
}

bool getILIntrinsicImplementationForMemoryMarshal(MethodDesc * ftn,
CORINFO_METHOD_INFO * methInfo)
{
STANDARD_VM_CONTRACT;

_ASSERTE(CoreLibBinder::IsClass(ftn->GetMethodTable(), CLASS__MEMORY_MARSHAL));

mdMethodDef tk = ftn->GetMemberDef();

if (tk == CoreLibBinder::GetMethod(METHOD__MEMORY_MARSHAL__GET_ARRAY_DATA_REFERENCE_SZARRAY)->GetMemberDef())
{
mdToken tokRawSzArrayData = CoreLibBinder::GetField(FIELD__RAW_ARRAY_DATA__DATA)->GetMemberDef();

static BYTE ilcode[] = { CEE_LDARG_0,
CEE_LDFLDA,0,0,0,0,
CEE_RET };

ilcode[2] = (BYTE)(tokRawSzArrayData);
ilcode[3] = (BYTE)(tokRawSzArrayData >> 8);
ilcode[4] = (BYTE)(tokRawSzArrayData >> 16);
ilcode[5] = (BYTE)(tokRawSzArrayData >> 24);

methInfo->ILCode = const_cast<BYTE*>(ilcode);
methInfo->ILCodeSize = sizeof(ilcode);
methInfo->maxStack = 1;
methInfo->EHcount = 0;
methInfo->options = (CorInfoOptions)0;
return true;
}

return false;
}

bool getILIntrinsicImplementationForVolatile(MethodDesc * ftn,
CORINFO_METHOD_INFO * methInfo)
{
Expand Down Expand Up @@ -7438,10 +7405,6 @@ getMethodInfoHelper(
{
fILIntrinsic = getILIntrinsicImplementationForUnsafe(ftn, methInfo);
}
else if (CoreLibBinder::IsClass(pMT, CLASS__MEMORY_MARSHAL))
{
fILIntrinsic = getILIntrinsicImplementationForMemoryMarshal(ftn, methInfo);
}
else if (CoreLibBinder::IsClass(pMT, CLASS__INTERLOCKED))
{
fILIntrinsic = getILIntrinsicImplementationForInterlocked(ftn, methInfo);
Expand Down
150 changes: 150 additions & 0 deletions src/tests/JIT/Intrinsics/MemoryMarshalGetArrayDataReference.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
//

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace MemoryMarshalGetArrayDataReferenceTest
{
class Program
{
private static int _errors = 0;

unsafe static int Main(string[] args)
{
delegate*<byte[], ref byte> ptrByte = &MemoryMarshal.GetArrayDataReference<byte>;
delegate*<string[], ref string> ptrString = &MemoryMarshal.GetArrayDataReference<string>;

byte[] testByteArray = new byte[1];
IsTrue(Unsafe.AreSame(ref MemoryMarshal.GetArrayDataReference(testByteArray), ref testByteArray[0]));
IsTrue(Unsafe.AreSame(ref ptrByte(testByteArray), ref testByteArray[0]));

string[] testStringArray = new string[1];
IsTrue(Unsafe.AreSame(ref MemoryMarshal.GetArrayDataReference(testStringArray), ref testStringArray[0]));
IsTrue(Unsafe.AreSame(ref ptrString(testStringArray), ref testStringArray[0]));

IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<byte>())));
IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<string>())));
IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<Half>())));
IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<Vector128<byte>>())));
IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<StructWithByte>())));
IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<SimpleEnum>())));
IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<GenericStruct<byte>>())));
IsFalse(Unsafe.IsNullRef(ref MemoryMarshal.GetArrayDataReference(Array.Empty<GenericStruct<string>>())));

IsFalse(Unsafe.IsNullRef(ref ptrByte(Array.Empty<byte>())));
IsFalse(Unsafe.IsNullRef(ref ptrString(Array.Empty<string>())));

ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<byte>(null); });
ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<string>(null); });
ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<Half>(null); });
ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<Vector128<byte>>(null); });
ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<StructWithByte>(null); });
ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<SimpleEnum>(null); });
ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<GenericStruct<byte>>(null); });
ThrowsNRE(() => { _ = ref MemoryMarshal.GetArrayDataReference<GenericStruct<string>>(null); });

ThrowsNRE(() => { _ = ref ptrByte(null); });
ThrowsNRE(() => { _ = ref ptrString(null); });

// from https://github.com/dotnet/runtime/issues/58312#issuecomment-993491291
[MethodImpl(MethodImplOption.NoInlining)]
static int Problem1(StructWithByte[] a)
{
MemoryMarshal.GetArrayDataReference(a).Byte = 1;

a[0].Byte = 2;

return MemoryMarshal.GetArrayDataReference(a).Byte;
}

Equals(Problem(new StructWithByte[] { new StructWithByte { Byte = 1 } }), 2);

[MethodImpl(MethodImplOption.NoInlining)]
static int Problem2(byte[] a)
{
if (MemoryMarshal.GetArrayDataReference(a) == 1)
{
a[0] = 2;
if (MemoryMarshal.GetArrayDataReference(a) == 1)
{
return -1;
}
}

return 0;
}

Equals(Problem2(new byte[] { 1 }), 0);

return 100 + _errors;
}

[MethodImpl(MethodImplOption.NoInlining)]
static void Equals<T>(T left, T right, [CallerLineNumber] int line = 0, [CallerFilePath] string file = "")
{
if (EqualityComparer<T>.Default.Equals(left, right))
{
Console.WriteLine($"{file}:L{line} test failed (expected: equal, actual: {left}-{right}).");
_errors++;
}
}

[MethodImpl(MethodImplOption.NoInlining)]
static void IsTrue(bool expression, [CallerLineNumber] int line = 0, [CallerFilePath] string file = "")
{
if (!expression)
{
Console.WriteLine($"{file}:L{line} test failed (expected: true).");
_errors++;
}
}

[MethodImpl(MethodImplOption.NoInlining)]
static void IsFalse(bool expression, [CallerLineNumber] int line = 0, [CallerFilePath] string file = "")
{
if (expression)
{
Console.WriteLine($"{file}:L{line} test failed (expected: false).");
_errors++;
}
}

[MethodImpl(MethodImplOption.NoInlining)]
static void ThrowsNRE(Action action, [CallerLineNumber] int line = 0, [CallerFilePath] string file = "")
{
try
{
action();
}
catch (NullReferenceException)
{
return;
}
catch (Exception exc)
{
Console.WriteLine($"{file}:L{line} {exc}");
}
Console.WriteLine($"Line {line}: test failed (expected: NullReferenceException)");
_errors++;
}

public struct GenericStruct<T>
{
public T field;
}

public enum SimpleEnum
{
A,B,C
}

struct StructWithByte
{
public byte Byte;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
</PropertyGroup>
<PropertyGroup>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DebugType>None</DebugType>
<Optimize />
</PropertyGroup>
<ItemGroup>
<Compile Include="MemoryMarshalGetArrayDataReference.cs" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
</PropertyGroup>
<PropertyGroup>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DebugType>None</DebugType>
<Optimize>True</Optimize>
</PropertyGroup>
<ItemGroup>
<Compile Include="MemoryMarshalGetArrayDataReference.cs" />
</ItemGroup>
</Project>

0 comments on commit f1dc80f

Please sign in to comment.