Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-ascii in fgVNBasedIntrinsicExpansionForCall_ReadUtf8 #89383

Merged
merged 10 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 46 additions & 38 deletions src/coreclr/jit/helperexpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

#include "jitpch.h"
#include <minipal/utf8.h>
#ifdef _MSC_VER
#pragma hdrstop
#endif
Expand Down Expand Up @@ -1238,7 +1239,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall(BasicBlock** pBlock, Statement

//------------------------------------------------------------------------------
// fgVNBasedIntrinsicExpansionForCall_ReadUtf8 : Expand NI_System_Text_UTF8Encoding_UTF8EncodingSealed_ReadUtf8
// when src data is a string literal (UTF16) that can be narrowed to ASCII (UTF8), e.g.:
// when src data is a string literal (UTF16) that can be converted to UTF8, e.g.:
//
// string str = "Hello, world!";
// int bytesWritten = ReadUtf8(ref str[0], str.Length, buffer, buffer.Length);
Expand Down Expand Up @@ -1282,59 +1283,66 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
return false;
}

assert(strObj != nullptr);

// We mostly expect string literal objects here, but let's be more agile just in case
if (!info.compCompHnd->isObjectImmutable(strObj))
{
JITDUMP("ReadUtf8: srcPtr is not immutable (not a frozen string object?)\n")
return false;
}

GenTree* srcLen = call->gtArgs.GetUserArgByIndex(1)->GetNode();
const GenTree* srcLen = call->gtArgs.GetUserArgByIndex(1)->GetNode();
if (!srcLen->gtVNPair.BothEqual() || !vnStore->IsVNInt32Constant(srcLen->gtVNPair.GetLiberal()))
{
JITDUMP("ReadUtf8: srcLen is not constant\n")
return false;
}

const int MaxPossibleUnrollThreshold = 256;
const unsigned unrollThreshold = min(getUnrollThreshold(UnrollKind::Memcpy), MaxPossibleUnrollThreshold);
const unsigned srcLenCns = (unsigned)vnStore->GetConstantInt32(srcLen->gtVNPair.GetLiberal());
if ((srcLenCns == 0) || (srcLenCns > unrollThreshold))
// Source UTF16 (U16) string length in characters
const unsigned srcLenCnsU16 = (unsigned)vnStore->GetConstantInt32(srcLen->gtVNPair.GetLiberal());
const int MaxU16BufferSizeInChars = 256;
if ((srcLenCnsU16 == 0) || (srcLenCnsU16 > MaxU16BufferSizeInChars))
EgorBo marked this conversation as resolved.
Show resolved Hide resolved
{
// TODO: handle srcLenCns == 0 if it's a common case
JITDUMP("ReadUtf8: srcLenCns is out of unrollable range\n")
JITDUMP("ReadUtf8: srcLenCns is 0 or > MaxPossibleUnrollThreshold\n")
return false;
}

// Read the string literal (UTF16) into a local buffer (UTF8)
assert(strObj != nullptr);
uint16_t bufferU16[MaxPossibleUnrollThreshold];
uint8_t bufferU8[MaxPossibleUnrollThreshold]; // twice smaller because of narrowing

// Both must be within [0..INT_MAX] range as we're going to cast them to int
assert((unsigned)srcLenCns <= INT_MAX);
assert((unsigned)strObjOffset <= INT_MAX);
uint16_t bufferU16[MaxU16BufferSizeInChars];

// getObjectContent is expected to validate the offset and length
if (!info.compCompHnd->getObjectContent(strObj, (uint8_t*)bufferU16, (int)srcLenCns * 2, (int)strObjOffset))
// NOTE: (int) casts should not overflow:
// * srcLenCns is <= MaxUTF16BufferSizeInChars
// * strObjOffset is already checked to be <= INT_MAX
if (!info.compCompHnd->getObjectContent(strObj, (uint8_t*)bufferU16, (int)(srcLenCnsU16 * sizeof(uint16_t)),
(int)strObjOffset))
{
JITDUMP("ReadUtf8: getObjectContent returned false.\n")
return false;
}

for (unsigned charIndex = 0; charIndex < srcLenCns; charIndex++)
const int MaxU8BufferSizeInBytes = 256;
uint8_t bufferU8[MaxU8BufferSizeInBytes];

const int srcLenU8 = (int)minipal_convert_utf16_to_utf8((const CHAR16_T*)bufferU16, srcLenCnsU16, (char*)bufferU8,
MaxU8BufferSizeInBytes, 0);
if (srcLenU8 <= 0)
{
// Buffer keeps the original utf16 chars
uint16_t ch = bufferU16[charIndex];
if (ch > 127)
{
// Only ASCII is supported.
JITDUMP("ReadUtf8: %dth char is not ASCII.\n", charIndex)
return false;
}
// E.g. output buffer is too small
JITDUMP("ReadUtf8: minipal_convert_utf16_to_utf8 returned <= 0\n")
return false;
}

// Narrow U16 to U8 in the same buffer
bufferU8[charIndex] = (uint8_t)ch;
// The API is expected to return [1..MaxU8BufferSizeInBytes] real length of the UTF-8 value
// stored in bufferU8
assert((unsigned)srcLenU8 <= MaxU8BufferSizeInBytes);

// Now that we know the exact UTF8 buffer length we can check if it's unrollable
if (srcLenU8 > (int)getUnrollThreshold(UnrollKind::Memcpy))
{
JITDUMP("ReadUtf8: srcLenU8 is out of unrollable range\n")
return false;
}

DebugInfo debugInfo = stmt->GetDebugInfo();
Expand Down Expand Up @@ -1373,23 +1381,23 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
fgMorphStmtBlockOps(block, stmt);
gtUpdateStmtSideEffects(stmt);

// srcLenCns is the length of the string literal in chars (UTF16)
// srcLenU8 is the length of the string literal in chars (UTF16)
// but we're going to use the same value as the "bytesWritten" result in the fast path and in the length check.
GenTree* srcLenCnsNode = gtNewIconNode(srcLenCns);
fgValueNumberTreeConst(srcLenCnsNode);
GenTree* srcLenU8Node = gtNewIconNode(srcLenU8);
fgValueNumberTreeConst(srcLenU8Node);

// We're going to insert the following blocks:
//
// prevBb:
//
// lengthCheckBb:
// bytesWritten = -1;
// if (dstLen <srcLen)
// if (dstLen < srcLenU8)
// goto block;
//
// fastpathBb:
// <unrolled block copy>
// bytesWritten = srcLenCns * 2;
// bytesWritten = srcLenU8;
//
// block:
// use(bytesWritten)
Expand All @@ -1406,7 +1414,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
fgInsertStmtAtEnd(lengthCheckBb, fgNewStmtFromTree(bytesWrittenDefaultVal, debugInfo));

GenTree* dstLen = call->gtArgs.GetUserArgByIndex(3)->GetNode();
GenTree* lengthCheck = gtNewOperNode(GT_LT, TYP_INT, gtCloneExpr(dstLen), srcLenCnsNode);
GenTree* lengthCheck = gtNewOperNode(GT_LT, TYP_INT, gtCloneExpr(dstLen), srcLenU8Node);
lengthCheck->gtFlags |= GTF_RELOP_JMP_USED;
Statement* lengthCheckStmt = fgNewStmtFromTree(gtNewOperNode(GT_JTRUE, TYP_VOID, lengthCheck), debugInfo);
fgInsertStmtAtEnd(lengthCheckBb, lengthCheckStmt);
Expand All @@ -1424,14 +1432,14 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
fastpathBb->bbFlags |= BBF_INTERNAL;

// The widest type we can use for loads
const var_types maxLoadType = roundDownMaxType(srcLenCns);
const var_types maxLoadType = roundDownMaxType(srcLenU8);
assert(genTypeSize(maxLoadType) > 0);

// How many iterations we need to copy UTF8 const data to the destination
unsigned iterations = srcLenCns / genTypeSize(maxLoadType);
unsigned iterations = srcLenU8 / genTypeSize(maxLoadType);

// Add one more iteration if we have a remainder
iterations += (srcLenCns % genTypeSize(maxLoadType) == 0) ? 0 : 1;
iterations += (srcLenU8 % genTypeSize(maxLoadType) == 0) ? 0 : 1;

GenTree* dstPtr = call->gtArgs.GetUserArgByIndex(2)->GetNode();
for (unsigned i = 0; i < iterations; i++)
Expand All @@ -1441,7 +1449,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
// Last iteration: overlap with previous load if needed
if (i == iterations - 1)
{
offset = (ssize_t)srcLenCns - genTypeSize(maxLoadType);
offset = (ssize_t)srcLenU8 - genTypeSize(maxLoadType);
}

// We're going to emit the following tree (in case of SIMD16 load):
Expand All @@ -1465,7 +1473,7 @@ bool Compiler::fgVNBasedIntrinsicExpansionForCall_ReadUtf8(BasicBlock** pBlock,
}

// Finally, store the number of bytes written to the resultLcl local
Statement* finalStmt = fgNewStmtFromTree(gtNewStoreLclVarNode(resultLclNum, gtCloneExpr(srcLenCnsNode)), debugInfo);
Statement* finalStmt = fgNewStmtFromTree(gtNewStoreLclVarNode(resultLclNum, gtCloneExpr(srcLenU8Node)), debugInfo);
fgInsertStmtAtEnd(fastpathBb, finalStmt);
fastpathBb->bbCodeOffs = block->bbCodeOffsEnd;
fastpathBb->bbCodeOffsEnd = block->bbCodeOffsEnd;
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/minipal/Windows/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(SOURCES
doublemapping.cpp
dn-u16.cpp
${CLR_SRC_NATIVE_DIR}/minipal/utf8.c
)

if(NOT CLR_CROSS_COMPONENTS_BUILD)
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/pal/inc/rt/cpp/stdbool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#include "palrt.h"
87 changes: 87 additions & 0 deletions src/tests/JIT/opt/Vectorization/ReadUtf8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Unicode;
using System.Threading;
using Xunit;

Expand All @@ -20,6 +21,11 @@ public static int TestEntryPoint()
Test_hello();
Test_CJK();
Test_SIMD();
Test_1();
Test_2();
Test_3();
Test_4();
Test_5();
Thread.Sleep(10);
}
return 100;
Expand Down Expand Up @@ -242,4 +248,85 @@ static void IsEmpty(Span<byte> span)
throw new Exception($"{item} != 0");
}
}

// ReadUtf8 is used inside Utf8.TryWrite + interpolation syntax:

[MethodImpl(MethodImplOptions.NoInlining)]
static void Test_1()
{
var buffer = new byte[1024];
ValidateResult("", Utf8.TryWrite(buffer, $"", out var written1), buffer, written1);
ValidateResult("1", Utf8.TryWrite(buffer, $"1", out var written2), buffer, written2);
ValidateResult("12", Utf8.TryWrite(buffer, $"12", out var written3), buffer, written3);
ValidateResult("123", Utf8.TryWrite(buffer, $"123", out var written4), buffer, written4);
ValidateResult("1234", Utf8.TryWrite(buffer, $"1234", out var written5), buffer, written5);
}

[MethodImpl(MethodImplOptions.NoInlining)]
static void Test_2()
{
var buffer = new byte[1024];
ValidateResult("12345", Utf8.TryWrite(buffer, $"12345", out var written1), buffer, written1);
ValidateResult("123456", Utf8.TryWrite(buffer, $"123456", out var written2), buffer, written2);
ValidateResult("1234567", Utf8.TryWrite(buffer, $"1234567", out var written3), buffer, written3);
ValidateResult("12345678", Utf8.TryWrite(buffer, $"12345678", out var written4), buffer, written4);
ValidateResult("123456789", Utf8.TryWrite(buffer, $"123456789", out var written5), buffer, written5);
}

[MethodImpl(MethodImplOptions.NoInlining)]
static void Test_3()
{
var buffer = new byte[1024];
ValidateResult("123456789A", Utf8.TryWrite(buffer, $"123456789A", out var written1), buffer, written1);
ValidateResult("123456789AB", Utf8.TryWrite(buffer, $"123456789AB", out var written2), buffer, written2);
ValidateResult("123456789ABC", Utf8.TryWrite(buffer, $"123456789ABC", out var written3), buffer, written3);
ValidateResult("123456789ABCD", Utf8.TryWrite(buffer, $"123456789ABCD", out var written4), buffer, written4);
ValidateResult("123456789ABCDE", Utf8.TryWrite(buffer, $"123456789ABCDE", out var written5), buffer, written5);
}

[MethodImpl(MethodImplOptions.NoInlining)]
static void Test_4()
{
var buffer = new byte[1024];
ValidateResult("123456789ABCDEF", Utf8.TryWrite(buffer, $"123456789ABCDEF", out var written1), buffer, written1);
ValidateResult("123456789ABCDEF\u0419", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419", out var written2), buffer, written2);
ValidateResult("123456789ABCDEF\u0419\u044C", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044C", out var written3), buffer, written3);
ValidateResult("123456789ABCDEF\u0419\u044Cf", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044Cf", out var written4), buffer, written4);
ValidateResult("123456789ABCDEF\u0419\u044Cf.", Utf8.TryWrite(buffer, $"123456789ABCDEF\u0419\u044Cf.", out var written5), buffer, written5);
}

[MethodImpl(MethodImplOptions.NoInlining)]
static void Test_5()
{
var buffer = new byte[1024];
ValidateResult("\uD800b", Utf8.TryWrite(buffer, $"\uD800b", out var written1), buffer, written1);
ValidateResult("1\uD800b", Utf8.TryWrite(buffer, $"1\uD800b", out var written2), buffer, written2);
ValidateResult("11\uD800b", Utf8.TryWrite(buffer, $"11\uD800b", out var written3), buffer, written3);
ValidateResult("\uD800b\uD800b", Utf8.TryWrite(buffer, $"\uD800b\uD800b", out var written4), buffer, written4);
ValidateResult("\uD800b435345435", Utf8.TryWrite(buffer, $"\uD800b435345435", out var written5), buffer, written5);
ValidateResult("342532523\uD800b\uD800b35235", Utf8.TryWrite(buffer, $"342532523\uD800b\uD800b35235", out var written6), buffer, written6);
ValidateResult("efewfwfwfwfwefwe\uD800bfewfw\uD800bwfwefew\uD800b", Utf8.TryWrite(buffer, $"efewfwfwfwfwefwe\uD800bfewfw\uD800bwfwefew\uD800b", out var written7), buffer, written7);
}

[MethodImpl(MethodImplOptions.NoInlining)]
static void ValidateResult(string str, bool actualResult, byte[] actualData, int actualBytesWritten)
{
byte[] expectedData = new byte[actualData.Length];
bool expectedResult = Utf8.TryWrite(expectedData, $"{str}", out int expectedBytesWritten);
if (expectedResult != actualResult)
{
throw new Exception($"Unexpected return value: {actualResult}");
}

if (actualBytesWritten != expectedBytesWritten)
{
throw new Exception($"bytesWritten value: {actualBytesWritten} != {expectedBytesWritten}");
}

if (expectedResult && !actualData.AsSpan(0, actualBytesWritten).SequenceEqual(
expectedData.AsSpan(0, expectedBytesWritten)))
{
throw new Exception("actualData != expectedData");
}
}
}