Skip to content

Commit

Permalink
[wasm] Correctly escape library names when generating symbols for .c (d…
Browse files Browse the repository at this point in the history
…otnet#79007)

* [wasm] Correctly escape library names when generating symbols for .c files.
Use the existing `FixupSymbolName` method for fixing library names too,
when converting to symbols.

* [wasm] *TableGenerator task: Cache the symbol name fixups
.. as it is called frequently, and for repeated strings. For a
consolewasm template build, we get 490 calls but only 140 of them are
for unique strings.

* Add tests

Fixes dotnet#78992 .
  • Loading branch information
radical authored Nov 30, 2022
1 parent c01ad86 commit 85a9dfc
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 67 deletions.
67 changes: 67 additions & 0 deletions src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Xunit;
using Xunit.Abstractions;

Expand Down Expand Up @@ -551,6 +552,72 @@ public void BuildNativeInNonEnglishCulture(BuildArgs buildArgs, string culture,
Assert.Contains("square: 25", output);
}

[Theory]
[BuildAndRun(host: RunHost.Chrome, parameters: new object[] { new object[] {
"with-hyphen",
"with#hash-and-hyphen",
"with.per.iod",
"with🚀unicode#"
} })]

public void CallIntoLibrariesWithNonAlphanumericCharactersInTheirNames(BuildArgs buildArgs, string[] libraryNames, RunHost host, string id)
{
buildArgs = ExpandBuildArgs(buildArgs,
extraItems: @$"<NativeFileReference Include=""*.c"" />",
extraProperties: buildArgs.AOT
? string.Empty
: "<WasmBuildNative>true</WasmBuildNative>");

int baseArg = 10;
(_, string output) = BuildProject(buildArgs,
id: id,
new BuildProjectOptions(
InitProject: () => GenerateSourceFiles(_projectDir!, baseArg),
Publish: buildArgs.AOT,
DotnetWasmFromRuntimePack: false
));

output = RunAndTestWasmApp(buildArgs,
buildDir: _projectDir,
expectedExitCode: 42,
host: host,
id: id);

for (int i = 0; i < libraryNames.Length; i ++)
{
Assert.Contains($"square_{i}: {(i + baseArg) * (i + baseArg)}", output);
}

void GenerateSourceFiles(string outputPath, int baseArg)
{
StringBuilder csBuilder = new($@"
using System;
using System.Runtime.InteropServices;
");

StringBuilder dllImportsBuilder = new();
for (int i = 0; i < libraryNames.Length; i ++)
{
dllImportsBuilder.AppendLine($"[DllImport(\"{libraryNames[i]}\")] static extern int square_{i}(int x);");
csBuilder.AppendLine($@"Console.WriteLine($""square_{i}: {{square_{i}({i + baseArg})}}"");");

string nativeCode = $@"
#include <stdarg.h>
int square_{i}(int x)
{{
return x * x;
}}";
File.WriteAllText(Path.Combine(outputPath, $"{libraryNames[i]}.c"), nativeCode);
}

csBuilder.AppendLine("return 42;");
csBuilder.Append(dllImportsBuilder);

File.WriteAllText(Path.Combine(outputPath, "Program.cs"), csBuilder.ToString());
}
}

private (BuildArgs, string) BuildForVariadicFunctionTests(string programText, BuildArgs buildArgs, string id, string? verbosity = null, string extraProperties = "")
{
extraProperties += "<AllowUnsafeBlocks>true</AllowUnsafeBlocks><_WasmDevel>true</_WasmDevel>";
Expand Down
12 changes: 7 additions & 5 deletions src/tasks/WasmAppBuilder/IcallTableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text;
Expand All @@ -23,8 +20,13 @@ internal sealed class IcallTableGenerator
private Dictionary<string, IcallClass> _runtimeIcalls = new Dictionary<string, IcallClass>();

private TaskLoggingHelper Log { get; set; }
private readonly Func<string, string> _fixupSymbolName;

public IcallTableGenerator(TaskLoggingHelper log) => Log = log;
public IcallTableGenerator(Func<string, string> fixupSymbolName, TaskLoggingHelper log)
{
Log = log;
_fixupSymbolName = fixupSymbolName;
}

//
// Given the runtime generated icall table, and a set of assemblies, generate
Expand Down Expand Up @@ -86,7 +88,7 @@ private void EmitTable(StreamWriter w)
if (assembly == "System.Private.CoreLib")
aname = "corlib";
else
aname = assembly.Replace(".", "_");
aname = _fixupSymbolName(assembly);
w.WriteLine($"#define ICALL_TABLE_{aname} 1\n");

w.WriteLine($"static int {aname}_icall_indexes [] = {{");
Expand Down
50 changes: 40 additions & 10 deletions src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Reflection;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;

#nullable enable

public class ManagedToNativeGenerator : Task
{
[Required]
Expand All @@ -37,6 +29,11 @@ public class ManagedToNativeGenerator : Task
[Output]
public string[]? FileWrites { get; private set; }

private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };

// Avoid sharing this cache with all the invocations of this task throughout the build
private readonly Dictionary<string, string> _symbolNameFixups = new();

public override bool Execute()
{
if (Assemblies!.Length == 0)
Expand Down Expand Up @@ -65,8 +62,8 @@ public override bool Execute()

private void ExecuteInternal()
{
var pinvoke = new PInvokeTableGenerator(Log);
var icall = new IcallTableGenerator(Log);
var pinvoke = new PInvokeTableGenerator(FixupSymbolName, Log);
var icall = new IcallTableGenerator(FixupSymbolName, Log);

IEnumerable<string> cookies = Enumerable.Concat(
pinvoke.Generate(PInvokeModules, Assemblies!, PInvokeOutputPath!),
Expand All @@ -80,4 +77,37 @@ private void ExecuteInternal()
? new string[] { PInvokeOutputPath, IcallOutputPath, InterpToNativeOutputPath }
: new string[] { PInvokeOutputPath, InterpToNativeOutputPath };
}

public string FixupSymbolName(string name)
{
if (_symbolNameFixups.TryGetValue(name, out string? fixedName))
return fixedName;

UTF8Encoding utf8 = new();
byte[] bytes = utf8.GetBytes(name);
StringBuilder sb = new();

foreach (byte b in bytes)
{
if ((b >= (byte)'0' && b <= (byte)'9') ||
(b >= (byte)'a' && b <= (byte)'z') ||
(b >= (byte)'A' && b <= (byte)'Z') ||
(b == (byte)'_'))
{
sb.Append((char)b);
}
else if (s_charsToReplace.Contains((char)b))
{
sb.Append('_');
}
else
{
sb.Append($"_{b:X}_");
}
}

fixedName = sb.ToString();
_symbolNameFixups[name] = fixedName;
return fixedName;
}
}
68 changes: 16 additions & 52 deletions src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@

internal sealed class PInvokeTableGenerator
{
private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };
private readonly Dictionary<Assembly, bool> _assemblyDisableRuntimeMarshallingAttributeCache = new();

private TaskLoggingHelper Log { get; set; }
private readonly Func<string, string> _fixupSymbolName;

public PInvokeTableGenerator(TaskLoggingHelper log) => Log = log;
public PInvokeTableGenerator(Func<string, string> fixupSymbolName, TaskLoggingHelper log)
{
Log = log;
_fixupSymbolName = fixupSymbolName;
}

public IEnumerable<string> Generate(string[] pinvokeModules, string[] assemblies, string outputPath)
{
Expand Down Expand Up @@ -234,14 +238,14 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules

foreach (var module in modules.Keys)
{
string symbol = ModuleNameToId(module) + "_imports";
string symbol = _fixupSymbolName(module) + "_imports";
w.WriteLine("static PinvokeImport " + symbol + " [] = {");

var assemblies_pinvokes = pinvokes.
Where(l => l.Module == module && !l.Skip).
OrderBy(l => l.EntryPoint).
GroupBy(d => d.EntryPoint).
Select(l => "{\"" + FixupSymbolName(l.Key) + "\", " + FixupSymbolName(l.Key) + "}, " +
Select(l => "{\"" + _fixupSymbolName(l.Key) + "\", " + _fixupSymbolName(l.Key) + "}, " +
"// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n)));

foreach (var pinvoke in assemblies_pinvokes)
Expand All @@ -255,7 +259,7 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules
w.Write("static void *pinvoke_tables[] = { ");
foreach (var module in modules.Keys)
{
string symbol = ModuleNameToId(module) + "_imports";
string symbol = _fixupSymbolName(module) + "_imports";
w.Write(symbol + ",");
}
w.WriteLine("};");
Expand All @@ -266,18 +270,6 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules
}
w.WriteLine("};");

static string ModuleNameToId(string name)
{
if (name.IndexOfAny(s_charsToReplace) < 0)
return name;

string fixedName = name;
foreach (char c in s_charsToReplace)
fixedName = fixedName.Replace(c, '_');

return fixedName;
}

static bool ShouldTreatAsVariadic(PInvoke[] candidates)
{
if (candidates.Length < 2)
Expand All @@ -295,43 +287,15 @@ static bool ShouldTreatAsVariadic(PInvoke[] candidates)
}
}

private static string FixupSymbolName(string name)
{
UTF8Encoding utf8 = new();
byte[] bytes = utf8.GetBytes(name);
StringBuilder sb = new();

foreach (byte b in bytes)
{
if ((b >= (byte)'0' && b <= (byte)'9') ||
(b >= (byte)'a' && b <= (byte)'z') ||
(b >= (byte)'A' && b <= (byte)'Z') ||
(b == (byte)'_'))
{
sb.Append((char)b);
}
else if (s_charsToReplace.Contains((char)b))
{
sb.Append('_');
}
else
{
sb.Append($"_{b:X}_");
}
}

return sb.ToString();
}

private static string SymbolNameForMethod(MethodInfo method)
private string SymbolNameForMethod(MethodInfo method)
{
StringBuilder sb = new();
Type? type = method.DeclaringType;
sb.Append($"{type!.Module!.Assembly!.GetName()!.Name!}_");
sb.Append($"{(type!.IsNested ? type!.FullName : type!.Name)}_");
sb.Append(method.Name);

return FixupSymbolName(sb.ToString());
return _fixupSymbolName(sb.ToString());
}

private static string MapType(Type t) => t.Name switch
Expand Down Expand Up @@ -374,7 +338,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
{
// FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types
// https://github.com/dotnet/runtime/issues/43791
sb.Append($"int {FixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);");
sb.Append($"int {_fixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);");
return sb.ToString();
}

Expand All @@ -390,7 +354,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
}

sb.Append(MapType(method.ReturnType));
sb.Append($" {FixupSymbolName(pinvoke.EntryPoint)} (");
sb.Append($" {_fixupSymbolName(pinvoke.EntryPoint)} (");
int pindex = 0;
var pars = method.GetParameters();
foreach (var p in pars)
Expand All @@ -404,7 +368,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
return sb.ToString();
}

private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback> callbacks)
private void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback> callbacks)
{
// Generate native->interp entry functions
// These are called by native code, so they need to obtain
Expand Down Expand Up @@ -450,7 +414,7 @@ private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback>

bool is_void = method.ReturnType.Name == "Void";

string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!);
uint token = (uint)method.MetadataToken;
string class_name = method.DeclaringType.Name;
string method_name = method.Name;
Expand Down Expand Up @@ -517,7 +481,7 @@ private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback>
foreach (var cb in callbacks)
{
var method = cb.Method;
string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!);
string class_name = method.DeclaringType.Name;
string method_name = method.Name;
w.WriteLine($"\"{module_symbol}_{class_name}_{method_name}\",");
Expand Down

0 comments on commit 85a9dfc

Please sign in to comment.