diff --git a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs index fd7324a..e175e96 100644 --- a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs +++ b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs @@ -73,6 +73,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disa private readonly HashSet removedParameters = []; private readonly Dictionary renamedLocalFunctions = []; private readonly ImmutableArray.Builder diagnostics = ImmutableArray.CreateBuilder(); + private readonly Stack replaceInInvocation = new(); private enum SyncOnlyDirectiveType { @@ -103,8 +104,8 @@ private enum SpecialMethod ExtensionExprSymbol? GetExtensionExprSymbol(InvocationExpressionSyntax invocation) => GetSymbol(invocation) is not IMethodSymbol - { IsExtensionMethod: true, ReducedFrom: { } reducedFrom, ReturnType: { } returnType } - ? null : new(invocation, reducedFrom, returnType); + { IsExtensionMethod: true, ReturnType: { } returnType } + ? null : new(invocation, returnType); while (curNode is { WhenNotNull: { } whenNotNull }) { @@ -146,6 +147,7 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression( var statements = new List(); var parameter = Identifier("param"); + if (leftOfTheDot is IdentifierNameSyntax ins && chain.Count == 1) { toCheckForNullExpr = ins; @@ -168,7 +170,7 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression( for (var i = 0; i < chain.Count; i++) { - var (callSymbol, reducedSymbol, returnType) = chain[i]; + var (callSymbol, returnType) = chain[i]; ExpressionSyntax firstArgument = funcArgumentType.IsValueType ? MemberAccessExpression( @@ -176,11 +178,12 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression( toCheckForNullExpr, IdentifierName(nameof(Nullable.Value))) : toCheckForNullExpr; - var unwrappedExpr = UnwrapExtension(callSymbol, /*Fixme*/ false, reducedSymbol, firstArgument); + replaceInInvocation.Push(firstArgument); + var unwrappedExpr = (InvocationExpressionSyntax)Visit(callSymbol); if (i == chain.Count - 1) { - ExpressionSyntax condition = returnType.IsValueType + ExpressionSyntax condition = funcArgumentType.IsValueType ? PrefixUnaryExpression( SyntaxKind.LogicalNotExpression, MemberAccessExpression( @@ -191,8 +194,8 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression( var castTo = MaybeNullableType(ProcessSymbol(returnType), returnType.IsValueType); var conditional = ConditionalExpression( - condition, - CastExpression(castTo, LiteralExpression(SyntaxKind.NullLiteralExpression)), + condition.AppendSpace(), + CastExpression(castTo, LiteralExpression(SyntaxKind.NullLiteralExpression)).PrependSpace().AppendSpace(), unwrappedExpr.PrependSpace()); lastExpression = conditional; @@ -204,12 +207,13 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression( var returnNullStatement = ReturnStatement(LiteralExpression(SyntaxKind.NullLiteralExpression).PrependSpace()); var ifBlock = ((ICollection)[returnNullStatement]).CreateBlock(3); - statements.Add(IfStatement(CheckNull(toCheckForNullExpr), ifBlock)); + var ifStatement = IfStatement(Token(SyntaxKind.IfKeyword), Token(SyntaxKind.OpenParenToken).PrependSpace(), CheckNull(toCheckForNullExpr), Token(SyntaxKind.CloseParenToken), ifBlock, null); + statements.Add(ifStatement); var toCheckForNull = Identifier($"check{i}"); var localType = ProcessSymbol(returnType); // reduced will return generic - var declarator = VariableDeclarator(toCheckForNull.AppendSpace(), null, EqualsValueClause(unwrappedExpr)); + var declarator = VariableDeclarator(toCheckForNull.AppendSpace(), null, EqualsValueClause(unwrappedExpr.PrependSpace())); var declaration = VariableDeclaration(localType.AppendSpace(), SeparatedList([declarator])); var intermediaryNullCheck = LocalDeclarationStatement(declaration); @@ -557,19 +561,31 @@ List RemoveFirstEndIf(SyntaxTriviaList list) newName = GetNewName(methodSymbol); } + var reducedFromExtensionMethod = methodSymbol.IsExtensionMethod ? methodSymbol.ReducedFrom : null; + + // Handle non null conditional access expression eg. arr?.First() + if (@base.Expression is MemberBindingExpressionSyntax + && reducedFromExtensionMethod is not null + && replaceInInvocation.Count > 0) + { + return UnwrapExtension(@base, isMemory, reducedFromExtensionMethod, replaceInInvocation.Pop()); + } + if (@base.Expression is not MemberAccessExpressionSyntax { } memberAccess) { return isMemory ? AppendSpan(@base) : @base; } + // Handle .ConfigureAwait() and return the part in front of it if (IsTaskExtension(methodSymbol)) { return memberAccess.Expression; } - if (methodSymbol is { IsExtensionMethod: true, ReducedFrom: { } reducedFrom }) + // Handle all other extension methods eg. arr.First() + if (reducedFromExtensionMethod is not null) { - return UnwrapExtension(@base, isMemory, reducedFrom, memberAccess.Expression); + return UnwrapExtension(@base, isMemory, reducedFromExtensionMethod, memberAccess.Expression); } if (memberAccess.Name is not SimpleNameSyntax { Identifier.ValueText: { } name }) @@ -2081,7 +2097,7 @@ public SyntaxList PostProcess(SyntaxList state } } - private sealed record ExtensionExprSymbol(InvocationExpressionSyntax InvocationExpression, IMethodSymbol ReducedFrom, ITypeSymbol ReturnType); + private sealed record ExtensionExprSymbol(InvocationExpressionSyntax InvocationExpression, ITypeSymbol ReturnType); private sealed record RemoveArgumentContext(bool IsNegated = false); } diff --git a/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj b/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj index 49fd226..fc3fd7a 100644 --- a/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj +++ b/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj @@ -2,11 +2,11 @@ - $(NoWarn);CA1303;CA1515;CA1812;CA1815;CA1822;CA1852;CA2000;CA5394 + $(NoWarn);CA1303;CA1508;CA1515;CA1812;CA1815;CA1822;CA1829;CA1852;CA2000;CA5394 $(NoWarn);CS0219;CS0162;CS1998;CS8603;CS8619 - $(NoWarn);IDE0005;IDE0011;IDE0035;IDE0058;IDE0060;IDE0065 + $(NoWarn);IDE0004;IDE0005;IDE0011;IDE0035;IDE0041;IDE0058;IDE0060;IDE0065 $(NoWarn);RS1035 - $(NoWarn);SA1200;SA1201;SA1400;SA1402;SA1403;SA1404;SA1601 + $(NoWarn);SA1200;SA1201;SA1204;SA1400;SA1402;SA1403;SA1404;SA1601 false net8.0;net6.0 $(TargetFrameworks);net472 diff --git a/tests/Generator.Tests/ConditionalExtensionTests.cs b/tests/Generator.Tests/ConditionalExtensionTests.cs index d733932..4ba186e 100644 --- a/tests/Generator.Tests/ConditionalExtensionTests.cs +++ b/tests/Generator.Tests/ConditionalExtensionTests.cs @@ -34,4 +34,21 @@ public Task LongChained() => """ ?.Where(z => 2 == 0) ?.Where(z => 3 == 0); """.Verify(sourceType: SourceType.MethodBody); + + [Fact] + public Task CheckArrayNullability() => """ +int[]? array = null; +var z = array?.Single(); +""".Verify(sourceType: SourceType.MethodBody); + +#if NET6_0_OR_GREATER + [Fact] + public Task ConditionalToExtension() => """ +[CreateSyncVersion] +public static async Task MethodAsync(IEnumerable? integers) +{ + var res = integers?.Chunk(2).First(); +} +""".Verify(); +#endif } diff --git a/tests/Generator.Tests/NullabilityTests.cs b/tests/Generator.Tests/NullabilityTests.cs index fdec76b..c3f4a88 100644 --- a/tests/Generator.Tests/NullabilityTests.cs +++ b/tests/Generator.Tests/NullabilityTests.cs @@ -89,13 +89,14 @@ partial class MyClass [CreateSyncVersion(OmitNullableDirective = true)] public async Task MethodAsync(int? myInt) { - _ = myInt?.DoSomething().DoSomething(); + _ = myInt?.DoSomething().DoSomething2(); } } internal static class Extension { public static int DoSomething(this int myInt) => myInt; + public static int DoSomething2(this int myInt) => myInt; } """.Verify(sourceType: SourceType.Full); diff --git a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.CheckArrayNullability#g.verified.cs b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.CheckArrayNullability#g.verified.cs new file mode 100644 index 0000000..c1f6f08 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.CheckArrayNullability#g.verified.cs @@ -0,0 +1,3 @@ +//HintName: Test.Class.MethodAsync.g.cs +int[]? array = null; +var z = ((object?)array == null ? (int?)null : global::System.Linq.Enumerable.Single(array)); diff --git a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.ConditionalToExtension#g.verified.cs b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.ConditionalToExtension#g.verified.cs new file mode 100644 index 0000000..148dc59 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.ConditionalToExtension#g.verified.cs @@ -0,0 +1,5 @@ +//HintName: Test.Class.MethodAsync.g.cs +public static void Method(global::System.Collections.Generic.IEnumerable? integers) +{ + var res = ((object?)integers == null ? (int[]?)null : global::System.Linq.Enumerable.First(global::System.Linq.Enumerable.Chunk(integers, 2))); +} diff --git a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.LongChained#g.verified.cs b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.LongChained#g.verified.cs index 4c5877b..14a1ddd 100644 --- a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.LongChained#g.verified.cs +++ b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.LongChained#g.verified.cs @@ -2,29 +2,29 @@ global::System.Reflection.Assembly? a = null; _ = ((global::System.Func?>)((param)=> { - if((object?)param == null) + if ((object?)param == null) { return null; } - global::System.Collections.Generic.IEnumerable check0 =global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param); - if((object?)check0 == null) + global::System.Collections.Generic.IEnumerable check0 = global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param); + if ((object?)check0 == null) { return null; } - global::System.Collections.Generic.List check1 =global::System.Linq.Enumerable.ToList(check0); - if((object?)check1 == null) + global::System.Collections.Generic.List check1 = global::System.Linq.Enumerable.ToList(check0); + if ((object?)check1 == null) { return null; } - global::System.Collections.Generic.IEnumerable check2 =global::System.Linq.Enumerable.Where(check1, z => 1 == 0); - if((object?)check2 == null) + global::System.Collections.Generic.IEnumerable check2 = global::System.Linq.Enumerable.Where(check1, z => 1 == 0); + if ((object?)check2 == null) { return null; } - global::System.Collections.Generic.IEnumerable check3 =global::System.Linq.Enumerable.Where(check2, z => 2 == 0); - return (object?)check3 == null?(global::System.Collections.Generic.IEnumerable?)null: global::System.Linq.Enumerable.Where(check3, z => 3 == 0); + global::System.Collections.Generic.IEnumerable check3 = global::System.Linq.Enumerable.Where(check2, z => 2 == 0); + return (object?)check3 == null ? (global::System.Collections.Generic.IEnumerable?)null : global::System.Linq.Enumerable.Where(check3, z => 3 == 0); }))(a); diff --git a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtension#g.verified.cs b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtension#g.verified.cs index 9183a4a..28378f2 100644 --- a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtension#g.verified.cs +++ b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtension#g.verified.cs @@ -1,3 +1,3 @@ //HintName: Test.Class.MethodAsync.g.cs global::System.Reflection.Assembly? a = null; -_ = ((object?)a == null?(global::System.Collections.Generic.IEnumerable?)null: global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(a)); +_ = ((object?)a == null ? (global::System.Collections.Generic.IEnumerable?)null : global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(a)); diff --git a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtensionChained#g.verified.cs b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtensionChained#g.verified.cs index 33aedda..59ec577 100644 --- a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtensionChained#g.verified.cs +++ b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.NullConditionalExtensionChained#g.verified.cs @@ -2,11 +2,11 @@ global::System.Reflection.Assembly? a = null; _ = ((global::System.Func?>)((param)=> { - if((object?)param == null) + if ((object?)param == null) { return null; } - global::System.Collections.Generic.IEnumerable check0 =global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param); - return (object?)check0 == null?(global::System.Collections.Generic.List?)null: global::System.Linq.Enumerable.ToList(check0); + global::System.Collections.Generic.IEnumerable check0 = global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param); + return (object?)check0 == null ? (global::System.Collections.Generic.List?)null : global::System.Linq.Enumerable.ToList(check0); }))(a); diff --git a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.StartWithExpression#g.verified.cs b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.StartWithExpression#g.verified.cs index 77509be..aff2b63 100644 --- a/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.StartWithExpression#g.verified.cs +++ b/tests/Generator.Tests/Snapshots/ConditionalExtensionTests.StartWithExpression#g.verified.cs @@ -1,2 +1,2 @@ //HintName: Test.Class.MethodAsync.g.cs -_ = ((global::System.Func?>)((param)=>(object?)param == null?(global::System.Collections.Generic.IEnumerable?)null: global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param)))(GetType().Assembly); +_ = ((global::System.Func?>)((param)=>(object?)param == null ? (global::System.Collections.Generic.IEnumerable?)null : global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param)))(GetType().Assembly); diff --git a/tests/Generator.Tests/Snapshots/ExtensionMethodTests.LeftOfTheDotTest#Tests.Class.MethodAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/ExtensionMethodTests.LeftOfTheDotTest#Tests.Class.MethodAsync.g.verified.cs index 70609de..be8bce9 100644 --- a/tests/Generator.Tests/Snapshots/ExtensionMethodTests.LeftOfTheDotTest#Tests.Class.MethodAsync.g.verified.cs +++ b/tests/Generator.Tests/Snapshots/ExtensionMethodTests.LeftOfTheDotTest#Tests.Class.MethodAsync.g.verified.cs @@ -6,6 +6,6 @@ partial class Class { public void Method() { - _ = ((global::System.Func)((param)=>(object?)param == null?(global::Tests.Bar?)null: global::Tests.BarExtension.DoSomething(param)))(global::Tests.Bar.Create()); + _ = ((global::System.Func)((param)=>(object?)param == null ? (global::Tests.Bar?)null : global::Tests.BarExtension.DoSomething(param)))(global::Tests.Bar.Create()); } } diff --git a/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullable#Test.MyClass.MethodAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullable#Test.MyClass.MethodAsync.g.verified.cs index 1dd5152..4060983 100644 --- a/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullable#Test.MyClass.MethodAsync.g.verified.cs +++ b/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullable#Test.MyClass.MethodAsync.g.verified.cs @@ -5,6 +5,6 @@ partial class MyClass { public void Method(global::System.IO.Stream stream) { - _ = ((object)stream == null?(global::System.IO.Stream)null: global::Test.Extension.DoSomething(stream)); + _ = ((object)stream == null ? (global::System.IO.Stream)null : global::Test.Extension.DoSomething(stream)); } } diff --git a/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableStructTwice#Test.MyClass.MethodAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableStructTwice#Test.MyClass.MethodAsync.g.verified.cs index 62c0d01..05cf96a 100644 --- a/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableStructTwice#Test.MyClass.MethodAsync.g.verified.cs +++ b/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableStructTwice#Test.MyClass.MethodAsync.g.verified.cs @@ -5,6 +5,6 @@ partial class MyClass { public void Method(int? myInt) { - _ = (!myInt.HasValue?(int?)null: global::Test.Extension.DoSomething(myInt.Value)); + _ = (!myInt.HasValue ? (int?)null : global::Test.Extension.DoSomething2(global::Test.Extension.DoSomething(myInt.Value))); } } diff --git a/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableTwice#Test.MyClass.MethodAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableTwice#Test.MyClass.MethodAsync.g.verified.cs index 2d8a413..78f5531 100644 --- a/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableTwice#Test.MyClass.MethodAsync.g.verified.cs +++ b/tests/Generator.Tests/Snapshots/NullabilityTests.ConditionalNullableTwice#Test.MyClass.MethodAsync.g.verified.cs @@ -7,13 +7,13 @@ public void Method(global::System.IO.Stream stream) { _ = ((global::System.Func)((param)=> { - if((object)param == null) + if ((object)param == null) { return null; } - global::System.IO.Stream check0 =global::Test.Extension.DoSomething(param); - return (object)check0 == null?(global::System.IO.Stream)null: global::Test.Extension.DoSomething(check0); + global::System.IO.Stream check0 = global::Test.Extension.DoSomething(param); + return (object)check0 == null ? (global::System.IO.Stream)null : global::Test.Extension.DoSomething(check0); }))(stream); } }