Skip to content

Commit

Permalink
[Source Gen] Generate function executor code (#1309)
Browse files Browse the repository at this point in the history
* Finally a green test for happy use case.

* WIP

* Handing async and void cases. More tests.

* 🙈

* 🙊

* 🙉

* Clean up and addressing some PR feedback

* Added the extension method to register the executor to service collection.

* Cleanup

* Rebased on main. Fixed package versions (to bump on only preview)

* Minor cleanup

* Updating tests to reflect the renaming of the public types (IInputBindingFeature)

* Cleanup after rebasing on main.

* Removed an unused line (from rebase merge)

* nit fixes based on PR feedback.

* Addressing PR feedback. raw literal string indendation is still driving me crazy!

* Changing the build prop to "FunctionsEnableExecutorSourceGen"

* Bit of a cleanup. Removed an unused type.

* Simplified GetTypesDictionary method

* Added summary comment for IsStatic prop.
  • Loading branch information
kshyju authored Mar 22, 2023
1 parent 37070d8 commit d3822f2
Show file tree
Hide file tree
Showing 17 changed files with 823 additions and 80 deletions.
20 changes: 12 additions & 8 deletions sdk/Sdk.Generators/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@ namespace Microsoft.Azure.Functions.Worker.Sdk.Generators
{
internal static class Constants
{
public static class BuildProperties
internal static class Languages
{
internal const string EnableSourceGenProp = "build_property.FunctionsEnableMetadataSourceGen";
internal const string DotnetIsolated = "dotnet-isolated";
}

public static class FileNames
internal static class BuildProperties
{
internal const string EnableSourceGen = "build_property.FunctionsEnableMetadataSourceGen";
internal const string EnablePlaceholder = "build_property.FunctionsEnableExecutorSourceGen";
}

internal static class FileNames
{
internal const string GeneratedFunctionMetadata = "GeneratedFunctionMetadataProvider.g.cs";
internal const string GeneratedFunctionExecutor = "GeneratedFunctionExecutor.g.cs";
}

public static class Types
internal static class Types
{
// Our types
internal const string BindingAttribute = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.BindingAttribute";
Expand All @@ -30,9 +37,6 @@ public static class Types
// System types
internal const string IEnumerable = "System.Collections.IEnumerable";
internal const string IEnumerableGeneric = "System.Collections.Generic.IEnumerable`1";
internal const string IEnumerableOfString = "System.Collections.Generic.IEnumerable`1<System.String>";
internal const string IEnumerableOfBinary = "System.Collections.Generic.IEnumerable`1<System.Byte[]>";
internal const string IEnumerableOfT = "System.Collections.Generic.IEnumerable`1<T>";
internal const string IEnumerableOfKeyValuePair = "System.Collections.Generic.IEnumerable`1<System.Collections.Generic.KeyValuePair`2<TKey,TValue>>";
internal const string String = "System.String";
internal const string ByteArray = "System.Byte[]";
Expand All @@ -45,7 +49,7 @@ public static class Types
internal const string DictionaryGeneric = "System.Collections.Generic.Dictionary`2";
}

public static class FunctionMetadataBindingProps {
internal static class FunctionMetadataBindingProps {
internal const string ReturnBindingName = "$return";
internal const string HttpResponseBindingName = "HttpResponse";
internal const string IsBatchedKey = "IsBatched";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;

namespace Microsoft.Azure.Functions.Worker.Sdk.Generators
{
public partial class FunctionExecutorGenerator
{
internal static class Emitter
{
internal static string Emit(IEnumerable<ExecutableFunction> functions, CancellationToken cancellationToken)
{
string result = $$"""
// <auto-generated/>
using System;
using System.Threading.Tasks;
using System.Collections.Generic;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Azure.Functions.Worker.Context.Features;
using Microsoft.Azure.Functions.Worker.Invocation;
namespace Microsoft.Azure.Functions.Worker
{
internal class DirectFunctionExecutor : IFunctionExecutor
{
private readonly IFunctionActivator _functionActivator;
{{GetTypesDictionary(functions)}}
public DirectFunctionExecutor(IFunctionActivator functionActivator)
{
_functionActivator = functionActivator ?? throw new ArgumentNullException(nameof(functionActivator));
}

public async ValueTask ExecuteAsync(FunctionContext context)
{
{{GetMethodBody(functions)}}
}
}
public static class FunctionExecutorHostBuilderExtensions
{
///<summary>
/// Configures an optimized function executor to the invocation pipeline.
///</summary>
public static IHostBuilder ConfigureGeneratedFunctionExecutor(this IHostBuilder builder)
{
return builder.ConfigureServices(s =>
{
s.AddSingleton<IFunctionExecutor, DirectFunctionExecutor>();
});
}
}
}
""";

return result;
}

private static string GetTypesDictionary(IEnumerable<ExecutableFunction> functions)
{
var classNames = functions.Where(f => !f.IsStatic).Select(f => f.ParentFunctionClassName).Distinct();
if (!classNames.Any())
{
return """

""";
}

return $$"""
private readonly Dictionary<string, Type> types = new()
{
{{string.Join("\n", classNames.Select(c => $$""" { "{{c}}", Type.GetType("{{c}}")! }"""))}},
};

""";
}

private static string GetMethodBody(IEnumerable<ExecutableFunction> functions)
{
var sb = new StringBuilder();
sb.Append(
"""
var inputBindingFeature = context.Features.Get<IFunctionInputBindingFeature>()!;
var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!;
var inputArguments = inputBindingResult.Values;

""");
foreach (ExecutableFunction function in functions)
{
sb.Append($$"""

if (string.Equals(context.FunctionDefinition.EntryPoint, "{{function.EntryPoint}}", StringComparison.OrdinalIgnoreCase))
{
""");

int functionParamCounter = 0;
var functionParamList = new List<string>();
foreach (var argumentTypeName in function.ParameterTypeNames)
{
functionParamList.Add($"({argumentTypeName})inputArguments[{functionParamCounter++}]");
}
var methodParamsStr = string.Join(", ", functionParamList);

if (!function.IsStatic)
{
sb.Append($$"""

var instanceType = types["{{function.ParentFunctionClassName}}"];
var i = _functionActivator.CreateInstance(instanceType, context) as {{function.ParentFunctionClassName}};
""");
}

sb.Append(@"
");

if (function.IsReturnValueAssignable)
{
sb.Append(@$"context.GetInvocationResult().Value = ");
}
if (function.ShouldAwait)
{
sb.Append("await ");
}

sb.Append(function.IsStatic
? @$"{function.ParentFunctionClassName}.{function.MethodName}({methodParamsStr});
}}"
: $@"i.{function.MethodName}({methodParamsStr});
}}");
}

return sb.ToString();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq;

namespace Microsoft.Azure.Functions.Worker.Sdk.Generators
{
/// <summary>
/// A type which holds information about the functions which can be executed from an invocation.
/// </summary>
internal class ExecutableFunction
{
/// <summary>
/// False if the function returns Task or void.
/// </summary>
internal bool IsReturnValueAssignable { set; get; }

/// <summary>
/// Whether the function should be awaited or not for getting the result of execution.
/// </summary>
internal bool ShouldAwait { get; set; }

/// <summary>
/// The method name (which is part of EntryPoint property value).
/// </summary>
internal string MethodName { get; set; } = null!;

/// <summary>
/// A value indicating whether the function is static or not.
/// </summary>
internal bool IsStatic { get; set; }

/// <summary>
/// Ex: MyNamespace.MyClass.MyMethodName
/// </summary>
internal string EntryPoint { get; set; } = null!;

/// <summary>
/// Fully qualified type name of the parent class.
/// Ex: MyNamespace.MyClass
/// </summary>
internal string ParentFunctionClassName { get; set; } = null!;

/// <summary>
/// A collection of fully qualified type names of the parameters of the function.
/// </summary>
internal IEnumerable<string> ParameterTypeNames { set; get; } = Enumerable.Empty<string>();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System.Collections.Generic;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Microsoft.Azure.Functions.Worker.Sdk.Generators
{
public partial class FunctionExecutorGenerator
{
internal sealed class Parser
{
private readonly GeneratorExecutionContext _context;
private readonly KnownTypes _knownTypes;

internal Parser(GeneratorExecutionContext context)
{
_context = context;
_knownTypes = new KnownTypes(_context.Compilation);
}

private Compilation Compilation => _context.Compilation;

internal ICollection<ExecutableFunction> GetFunctions(List<MethodDeclarationSyntax> methods)
{
var functionList = new List<ExecutableFunction>();

foreach (MethodDeclarationSyntax method in methods)
{
_context.CancellationToken.ThrowIfCancellationRequested();
var model = Compilation.GetSemanticModel(method.SyntaxTree);

if (!FunctionsUtil.IsValidFunctionMethod(_context, Compilation, model, method,
out _))
{
continue;
}

var methodName = method.Identifier.Text;
var methodParameterList = new List<string>(method.ParameterList.Parameters.Count);

foreach (var methodParam in method.ParameterList.Parameters)
{
if (model.GetDeclaredSymbol(methodParam) is not IParameterSymbol parameterSymbol)
{
continue;
}

methodParameterList.Add(parameterSymbol.Type.ToDisplayString());
}

var methodSymbol = model.GetDeclaredSymbol(method)!;
var fullyQualifiedClassName = methodSymbol.ContainingSymbol.ToDisplayString();

var function = new ExecutableFunction
{
EntryPoint = $"{fullyQualifiedClassName}.{method.Identifier.ValueText}",
ParameterTypeNames = methodParameterList,
MethodName = methodName,
ShouldAwait = IsTaskType(methodSymbol.ReturnType),
IsReturnValueAssignable = IsReturnValueAssignable(methodSymbol),
IsStatic = method.Modifiers.Any(SyntaxKind.StaticKeyword),
ParentFunctionClassName = fullyQualifiedClassName
};

functionList.Add(function);
}

return functionList;
}

/// <summary>
/// Returns true if the symbol is Task/Task of T/ValueTask/ValueTask of T.
/// </summary>
private bool IsTaskType(ITypeSymbol typeSymbol)
{
return
SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, _knownTypes.TaskType) ||
SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, _knownTypes.TaskOfTType) ||
SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, _knownTypes.ValueTaskType) ||
SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, _knownTypes.ValueTaskOfTTypeOpt);
}

/// <summary>
/// Is the return value of the method assignable to a variable?
/// Returns True for methods which has Task or void as return type.
/// </summary>
private bool IsReturnValueAssignable(IMethodSymbol methodSymbol)
{
if (methodSymbol.ReturnsVoid)
{
return false;
}

if (SymbolEqualityComparer.Default.Equals(methodSymbol.ReturnType.OriginalDefinition, _knownTypes.TaskType))
{
return false;
}

if (SymbolEqualityComparer.Default.Equals(methodSymbol.ReturnType.OriginalDefinition,
_knownTypes.ValueTaskType))
{
return false;
}

return true;
}
}
}
}
Loading

0 comments on commit d3822f2

Please sign in to comment.