Skip to content

Commit

Permalink
Fix #78 Improve Func<Task> to Action conversion with nested Func and …
Browse files Browse the repository at this point in the history
…explicit Invoke
  • Loading branch information
virzak committed Jul 7, 2024
1 parent 7e6da6e commit bbb542f
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 26 deletions.
86 changes: 63 additions & 23 deletions src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -448,23 +448,19 @@ List<SyntaxTrivia> RemoveFirstEndIf(SyntaxTriviaList list)
{
var @base = (ParameterSyntax)base.VisitParameter(node)!;

if (node.Type is not null && @base.Type is not null)
{
var originalType = node.Type;
var variableType = GetSymbol(originalType);
return node.Type is { } originalType && @base.Type is not null

if (variableType is INamedTypeSymbol { IsGenericType: true } namedTypeSymbol
&& GetNameWithoutTypeParams(namedTypeSymbol) is SystemFunc)
{
if (ConvertFuncToAction(@base.Type, namedTypeSymbol) is { } newTypeSyntax)
{
return @base.WithType(newTypeSyntax.WithTriviaFrom(originalType));
}
}
}
// If type is generic
&& GetSymbol(originalType) is INamedTypeSymbol { IsGenericType: true } namedTypeSymbol

// And it is System.Func
&& GetNameWithoutTypeParams(namedTypeSymbol) is SystemFunc

return node.Type is null || TypeAlreadyQualified(node.Type) ? @base
: @base.WithType(ProcessType(node.Type)).WithTriviaFrom(@base);
// And can be converter to Action
&& ConvertFuncToAction(@base.Type, namedTypeSymbol) is { } newTypeSyntax
? @base.WithType(newTypeSyntax.WithTriviaFrom(originalType))
: (node.Type is null || TypeAlreadyQualified(node.Type) ? @base
: @base.WithType(ProcessType(node.Type)).WithTriviaFrom(@base));
}

/// <inheritdoc/>
Expand Down Expand Up @@ -1624,30 +1620,56 @@ private static string GetNewName(IMethodSymbol methodSymbol)
{
var typeArgs = namedTypeSymbol.TypeArguments;

if (typeArgs[^1].ToString() is not TaskType and not ValueTaskType
if (typeArgs[^1] is not INamedTypeSymbol lastTypeArgument
|| originalType is not GenericNameSyntax gns)
{
return null;
}

if (GetNameWithoutTypeParams(lastTypeArgument) is SystemFunc)
{
if (ConvertFuncToAction(gns.TypeArgumentList.Arguments[^1], lastTypeArgument) is not { } child)
{
return null;
}

var newList = (List<TypeSyntax>)[.. gns.TypeArgumentList.Arguments];
newList.RemoveAt(newList.Count - 1);
newList.Add(child);
var newSeparatedList = SeparatedList(newList, gns.TypeArgumentList.Arguments.GetSeparators());
var res = gns.WithTypeArgumentList(TypeArgumentList(newSeparatedList));
return res;
}
else if (lastTypeArgument.ToString() is not TaskType and not ValueTaskType)
{
return null;
}

TypeSyntax newTypeSyntax;
var newType = Global("System.Action");

var list = new List<TypeSyntax>();
if (typeArgs.Length > 1)
{
// Func<something, Task> => Action<something>
for (var i = 0; i < typeArgs.Length - 1; i++)
{
list.Add(ProcessSymbol(typeArgs[i]));
}

var originalSeparators = gns.TypeArgumentList.Arguments.GetSeparators();
var originalSeparators = (List<SyntaxToken>)[.. gns.TypeArgumentList.Arguments.GetSeparators()];

if (originalSeparators.Count > 0 && list.Count == originalSeparators.Count)
{
originalSeparators.RemoveAt(originalSeparators.Count - 1);
}

var separatedList = SeparatedList(list, originalSeparators);
newTypeSyntax = GenericName(Identifier(newType), TypeArgumentList(separatedList));
}
else
{
// Func<Task> => Action
newTypeSyntax = IdentifierName(newType);
}

Expand Down Expand Up @@ -1950,16 +1972,34 @@ private bool DropInvocation(InvocationExpressionSyntax invocation)
}

if (symbol is IMethodSymbol methodSymbol
&& expression is MemberAccessExpressionSyntax memberAccessExpression)
&& expression is MemberAccessExpressionSyntax memberAccessExpression
&& IsTaskExtension(methodSymbol) && memberAccessExpression.Expression is InvocationExpressionSyntax childInvocation)
{
if (IsTaskExtension(methodSymbol) && memberAccessExpression.Expression is InvocationExpressionSyntax childInvocation)
{
return DropInvocation(childInvocation);
}
return DropInvocation(childInvocation);
}

IParameterSymbol? GetParameter(ISymbol symbol, InvocationExpressionSyntax node) => symbol switch
{
IParameterSymbol ps => ps,
IMethodSymbol { MethodKind: MethodKind.DelegateInvoke }
=> node.Expression switch
{
MemberAccessExpressionSyntax { Expression: InvocationExpressionSyntax mae }
when GetSymbol(mae.Expression) is { } parentSymbol
=> GetParameter(parentSymbol, mae),
MemberAccessExpressionSyntax mae
when GetSymbol(mae.Expression) is IParameterSymbol ps
=> ps,
InvocationExpressionSyntax parentIes
when GetSymbol(parentIes.Expression) is { } parentSymbol
=> GetParameter(parentSymbol, parentIes),
_ => null,
},
_ => null,
};

// Ensure that if a parameter is called, which hasn't been removed, invocation isn't dropped.
return (symbol is not IParameterSymbol ps || removedParameters.Contains(ps)) && HasSymbolAndShouldBeRemoved(invocation);
return (GetParameter(symbol, invocation) is not IParameterSymbol ps || removedParameters.Contains(ps)) && HasSymbolAndShouldBeRemoved(invocation);
}

private bool ShouldRemoveArrowExpression(ArrowExpressionClauseSyntax? arrowNullable)
Expand Down
4 changes: 2 additions & 2 deletions tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
<!--Suppress warninigs alphabetically and numerically-->
<NoWarn>$(NoWarn);CA1303;CA1508;CA1515;CA1812;CA1815;CA1822;CA1829;CA1852;CA2000;CA5394</NoWarn>
<NoWarn>$(NoWarn);CS0219;CS0162;CS1998;CS8603;CS8619</NoWarn>
<NoWarn>$(NoWarn);IDE0004;IDE0005;IDE0011;IDE0035;IDE0041;IDE0058;IDE0060;IDE0065</NoWarn>
<NoWarn>$(NoWarn);IDE0004;IDE0005;IDE0011;IDE0035;IDE0040;IDE0041;IDE0058;IDE0060;IDE0065</NoWarn>
<NoWarn>$(NoWarn);RS1035</NoWarn>
<NoWarn>$(NoWarn);SA1200;SA1201;SA1204;SA1400;SA1402;SA1403;SA1404;SA1601</NoWarn>
<NoWarn>$(NoWarn);SA1200;SA1201;SA1202;SA1204;SA1400;SA1402;SA1403;SA1404;SA1601</NoWarn>
<ImplicitUsings>false</ImplicitUsings>
<TargetFrameworks>net8.0;net6.0</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' == 'Windows_NT'">$(TargetFrameworks);net472</TargetFrameworks>
Expand Down
38 changes: 37 additions & 1 deletion tests/Generator.Tests/DelegateTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,42 @@ async Task MethodAsync(Func<int, Task> bar)
{
await bar(5);
}
""".Verify();

[Fact]
public Task FuncToActionInParameterExplicitInvoke() => """
[CreateSyncVersion]
async Task MethodAsync(Func<Task> bar)
{
await bar.Invoke();
}
""".Verify();

[Fact]
public Task FuncToActionInParameterTwice() => """
[CreateSyncVersion]
async Task MethodAsync(Func<Func<Task>> bar)
{
await bar()();
}
""".Verify();

[Fact]
public Task FuncToActionInParameterTwiceExplicitInvoke() => """
[CreateSyncVersion]
async Task MethodAsync(Func<Func<Task>> bar)
{
await bar().Invoke();
}
""".Verify();

[Fact]
public Task FuncToActionInParameterTwiceWithArgument() => """
[CreateSyncVersion]
async Task MethodAsync(Func<int, Func<Task>> bar)
{
await bar(9)();
}
""".Verify();

[Theory]
Expand Down Expand Up @@ -63,7 +99,7 @@ public Task AsyncDelegateWithIProgress(string iProgressArg) => $"""

[InlineData("null")]
[InlineData("new Progress<int>()")]
public Task AsyncDelegateWithMullableIProgress(string iProgressArg) => $"""
public Task AsyncDelegateWithNullableIProgress(string iProgressArg) => $"""
Func<IProgress<int>?, CancellationToken, Task<int>> delAsync = async (p, ct) => await Task.FromResult(2);
var result = await delAsync({iProgressArg}, CancellationToken.None);
""".Verify(disableUnique: true, sourceType: SourceType.MethodBody);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.MethodAsync.g.cs
void Method(global::System.Action bar)
{
bar.Invoke();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.MethodAsync.g.cs
void Method(global::System.Func<global::System.Action> bar)
{
bar()();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.MethodAsync.g.cs
void Method(global::System.Func<global::System.Action> bar)
{
bar().Invoke();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.MethodAsync.g.cs
void Method(global::System.Func<int, global::System.Action> bar)
{
bar(9)();
}

0 comments on commit bbb542f

Please sign in to comment.