Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Correctly escape library names when generating symbols for .c #79007

Merged
merged 8 commits into from
Nov 30, 2022
67 changes: 67 additions & 0 deletions src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Xunit;
using Xunit.Abstractions;

Expand Down Expand Up @@ -551,6 +552,72 @@ public void BuildNativeInNonEnglishCulture(BuildArgs buildArgs, string culture,
Assert.Contains("square: 25", output);
}

[Theory]
[BuildAndRun(host: RunHost.Chrome, parameters: new object[] { new object[] {
"with-hyphen",
"with#hash-and-hyphen",
"with.per.iod",
"with🚀unicode#"
} })]

public void CallIntoLibrariesWithNonAlphanumericCharactersInTheirNames(BuildArgs buildArgs, string[] libraryNames, RunHost host, string id)
{
buildArgs = ExpandBuildArgs(buildArgs,
extraItems: @$"<NativeFileReference Include=""*.c"" />",
extraProperties: buildArgs.AOT
? string.Empty
: "<WasmBuildNative>true</WasmBuildNative>");

int baseArg = 10;
(_, string output) = BuildProject(buildArgs,
id: id,
new BuildProjectOptions(
InitProject: () => GenerateSourceFiles(_projectDir!, baseArg),
Publish: buildArgs.AOT,
DotnetWasmFromRuntimePack: false
));

output = RunAndTestWasmApp(buildArgs,
buildDir: _projectDir,
expectedExitCode: 42,
host: host,
id: id);

for (int i = 0; i < libraryNames.Length; i ++)
{
Assert.Contains($"square_{i}: {(i + baseArg) * (i + baseArg)}", output);
}

void GenerateSourceFiles(string outputPath, int baseArg)
{
StringBuilder csBuilder = new($@"
using System;
using System.Runtime.InteropServices;
");

StringBuilder dllImportsBuilder = new();
for (int i = 0; i < libraryNames.Length; i ++)
{
dllImportsBuilder.AppendLine($"[DllImport(\"{libraryNames[i]}\")] static extern int square_{i}(int x);");
csBuilder.AppendLine($@"Console.WriteLine($""square_{i}: {{square_{i}({i + baseArg})}}"");");

string nativeCode = $@"
#include <stdarg.h>

int square_{i}(int x)
{{
return x * x;
}}";
File.WriteAllText(Path.Combine(outputPath, $"{libraryNames[i]}.c"), nativeCode);
}

csBuilder.AppendLine("return 42;");
csBuilder.Append(dllImportsBuilder);

File.WriteAllText(Path.Combine(outputPath, "Program.cs"), csBuilder.ToString());
}
}

private (BuildArgs, string) BuildForVariadicFunctionTests(string programText, BuildArgs buildArgs, string id, string? verbosity = null, string extraProperties = "")
{
extraProperties += "<AllowUnsafeBlocks>true</AllowUnsafeBlocks><_WasmDevel>true</_WasmDevel>";
Expand Down
12 changes: 7 additions & 5 deletions src/tasks/WasmAppBuilder/IcallTableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text;
Expand All @@ -23,8 +20,13 @@ internal sealed class IcallTableGenerator
private Dictionary<string, IcallClass> _runtimeIcalls = new Dictionary<string, IcallClass>();

private TaskLoggingHelper Log { get; set; }
private readonly Func<string, string> _fixupSymbolName;

public IcallTableGenerator(TaskLoggingHelper log) => Log = log;
public IcallTableGenerator(Func<string, string> fixupSymbolName, TaskLoggingHelper log)
{
Log = log;
_fixupSymbolName = fixupSymbolName;
}

//
// Given the runtime generated icall table, and a set of assemblies, generate
Expand Down Expand Up @@ -86,7 +88,7 @@ private void EmitTable(StreamWriter w)
if (assembly == "System.Private.CoreLib")
aname = "corlib";
else
aname = assembly.Replace(".", "_");
aname = _fixupSymbolName(assembly);
w.WriteLine($"#define ICALL_TABLE_{aname} 1\n");

w.WriteLine($"static int {aname}_icall_indexes [] = {{");
Expand Down
50 changes: 40 additions & 10 deletions src/tasks/WasmAppBuilder/ManagedToNativeGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Reflection;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;

#nullable enable

public class ManagedToNativeGenerator : Task
{
[Required]
Expand All @@ -37,6 +29,11 @@ public class ManagedToNativeGenerator : Task
[Output]
public string[]? FileWrites { get; private set; }

private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };

// Avoid sharing this cache with all the invocations of this task throughout the build
private readonly Dictionary<string, string> _symbolNameFixups = new();

public override bool Execute()
{
if (Assemblies!.Length == 0)
Expand Down Expand Up @@ -65,8 +62,8 @@ public override bool Execute()

private void ExecuteInternal()
{
var pinvoke = new PInvokeTableGenerator(Log);
var icall = new IcallTableGenerator(Log);
var pinvoke = new PInvokeTableGenerator(FixupSymbolName, Log);
var icall = new IcallTableGenerator(FixupSymbolName, Log);

IEnumerable<string> cookies = Enumerable.Concat(
pinvoke.Generate(PInvokeModules, Assemblies!, PInvokeOutputPath!),
Expand All @@ -80,4 +77,37 @@ private void ExecuteInternal()
? new string[] { PInvokeOutputPath, IcallOutputPath, InterpToNativeOutputPath }
: new string[] { PInvokeOutputPath, InterpToNativeOutputPath };
}

public string FixupSymbolName(string name)
{
if (_symbolNameFixups.TryGetValue(name, out string? fixedName))
return fixedName;

UTF8Encoding utf8 = new();
byte[] bytes = utf8.GetBytes(name);
StringBuilder sb = new();

foreach (byte b in bytes)
{
if ((b >= (byte)'0' && b <= (byte)'9') ||
(b >= (byte)'a' && b <= (byte)'z') ||
(b >= (byte)'A' && b <= (byte)'Z') ||
(b == (byte)'_'))
{
sb.Append((char)b);
}
else if (s_charsToReplace.Contains((char)b))
{
sb.Append('_');
}
else
{
sb.Append($"_{b:X}_");
}
}

fixedName = sb.ToString();
_symbolNameFixups[name] = fixedName;
return fixedName;
}
}
68 changes: 16 additions & 52 deletions src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@

internal sealed class PInvokeTableGenerator
{
private static readonly char[] s_charsToReplace = new[] { '.', '-', '+' };
private readonly Dictionary<Assembly, bool> _assemblyDisableRuntimeMarshallingAttributeCache = new();

private TaskLoggingHelper Log { get; set; }
private readonly Func<string, string> _fixupSymbolName;

public PInvokeTableGenerator(TaskLoggingHelper log) => Log = log;
public PInvokeTableGenerator(Func<string, string> fixupSymbolName, TaskLoggingHelper log)
{
Log = log;
_fixupSymbolName = fixupSymbolName;
}

public IEnumerable<string> Generate(string[] pinvokeModules, string[] assemblies, string outputPath)
{
Expand Down Expand Up @@ -234,14 +238,14 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules

foreach (var module in modules.Keys)
{
string symbol = ModuleNameToId(module) + "_imports";
string symbol = _fixupSymbolName(module) + "_imports";
w.WriteLine("static PinvokeImport " + symbol + " [] = {");

var assemblies_pinvokes = pinvokes.
Where(l => l.Module == module && !l.Skip).
OrderBy(l => l.EntryPoint).
GroupBy(d => d.EntryPoint).
Select(l => "{\"" + FixupSymbolName(l.Key) + "\", " + FixupSymbolName(l.Key) + "}, " +
Select(l => "{\"" + _fixupSymbolName(l.Key) + "\", " + _fixupSymbolName(l.Key) + "}, " +
"// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n)));

foreach (var pinvoke in assemblies_pinvokes)
Expand All @@ -255,7 +259,7 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules
w.Write("static void *pinvoke_tables[] = { ");
foreach (var module in modules.Keys)
{
string symbol = ModuleNameToId(module) + "_imports";
string symbol = _fixupSymbolName(module) + "_imports";
w.Write(symbol + ",");
}
w.WriteLine("};");
Expand All @@ -266,18 +270,6 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary<string, string> modules
}
w.WriteLine("};");

static string ModuleNameToId(string name)
{
if (name.IndexOfAny(s_charsToReplace) < 0)
return name;

string fixedName = name;
foreach (char c in s_charsToReplace)
fixedName = fixedName.Replace(c, '_');

return fixedName;
}

static bool ShouldTreatAsVariadic(PInvoke[] candidates)
{
if (candidates.Length < 2)
Expand All @@ -295,43 +287,15 @@ static bool ShouldTreatAsVariadic(PInvoke[] candidates)
}
}

private static string FixupSymbolName(string name)
{
UTF8Encoding utf8 = new();
byte[] bytes = utf8.GetBytes(name);
StringBuilder sb = new();

foreach (byte b in bytes)
{
if ((b >= (byte)'0' && b <= (byte)'9') ||
(b >= (byte)'a' && b <= (byte)'z') ||
(b >= (byte)'A' && b <= (byte)'Z') ||
(b == (byte)'_'))
{
sb.Append((char)b);
}
else if (s_charsToReplace.Contains((char)b))
{
sb.Append('_');
}
else
{
sb.Append($"_{b:X}_");
}
}

return sb.ToString();
}

private static string SymbolNameForMethod(MethodInfo method)
private string SymbolNameForMethod(MethodInfo method)
{
StringBuilder sb = new();
Type? type = method.DeclaringType;
sb.Append($"{type!.Module!.Assembly!.GetName()!.Name!}_");
sb.Append($"{(type!.IsNested ? type!.FullName : type!.Name)}_");
sb.Append(method.Name);

return FixupSymbolName(sb.ToString());
return _fixupSymbolName(sb.ToString());
}

private static string MapType(Type t) => t.Name switch
Expand Down Expand Up @@ -374,7 +338,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
{
// FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types
// https://github.com/dotnet/runtime/issues/43791
sb.Append($"int {FixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);");
sb.Append($"int {_fixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);");
return sb.ToString();
}

Expand All @@ -390,7 +354,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
}

sb.Append(MapType(method.ReturnType));
sb.Append($" {FixupSymbolName(pinvoke.EntryPoint)} (");
sb.Append($" {_fixupSymbolName(pinvoke.EntryPoint)} (");
int pindex = 0;
var pars = method.GetParameters();
foreach (var p in pars)
Expand All @@ -404,7 +368,7 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN
return sb.ToString();
}

private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback> callbacks)
private void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback> callbacks)
{
// Generate native->interp entry functions
// These are called by native code, so they need to obtain
Expand Down Expand Up @@ -450,7 +414,7 @@ private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback>

bool is_void = method.ReturnType.Name == "Void";

string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!);
uint token = (uint)method.MetadataToken;
string class_name = method.DeclaringType.Name;
string method_name = method.Name;
Expand Down Expand Up @@ -517,7 +481,7 @@ private static void EmitNativeToInterp(StreamWriter w, ref List<PInvokeCallback>
foreach (var cb in callbacks)
{
var method = cb.Method;
string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!.Replace(".", "_");
string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!);
string class_name = method.DeclaringType.Name;
string method_name = method.Name;
w.WriteLine($"\"{module_symbol}_{class_name}_{method_name}\",");
Expand Down