Skip to content

Commit

Permalink
Fix #79 Improve value type check in conditional access expression, fi…
Browse files Browse the repository at this point in the history
…x spacing
  • Loading branch information
virzak committed Jul 6, 2024
1 parent 89dcc1a commit a1c0e95
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 36 deletions.
40 changes: 28 additions & 12 deletions src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disa
private readonly HashSet<IParameterSymbol> removedParameters = [];
private readonly Dictionary<string, string> renamedLocalFunctions = [];
private readonly ImmutableArray<ReportedDiagnostic>.Builder diagnostics = ImmutableArray.CreateBuilder<ReportedDiagnostic>();
private readonly Stack<ExpressionSyntax> replaceInInvocation = new();

private enum SyncOnlyDirectiveType
{
Expand Down Expand Up @@ -103,8 +104,8 @@ private enum SpecialMethod

ExtensionExprSymbol? GetExtensionExprSymbol(InvocationExpressionSyntax invocation)
=> GetSymbol(invocation) is not IMethodSymbol
{ IsExtensionMethod: true, ReducedFrom: { } reducedFrom, ReturnType: { } returnType }
? null : new(invocation, reducedFrom, returnType);
{ IsExtensionMethod: true, ReturnType: { } returnType }
? null : new(invocation, returnType);

while (curNode is { WhenNotNull: { } whenNotNull })
{
Expand Down Expand Up @@ -146,6 +147,7 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression(

var statements = new List<StatementSyntax>();
var parameter = Identifier("param");

if (leftOfTheDot is IdentifierNameSyntax ins && chain.Count == 1)
{
toCheckForNullExpr = ins;
Expand All @@ -168,19 +170,20 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression(

for (var i = 0; i < chain.Count; i++)
{
var (callSymbol, reducedSymbol, returnType) = chain[i];
var (callSymbol, returnType) = chain[i];

ExpressionSyntax firstArgument = funcArgumentType.IsValueType
? MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
toCheckForNullExpr,
IdentifierName(nameof(Nullable<int>.Value)))
: toCheckForNullExpr;
var unwrappedExpr = UnwrapExtension(callSymbol, /*Fixme*/ false, reducedSymbol, firstArgument);
replaceInInvocation.Push(firstArgument);
var unwrappedExpr = (InvocationExpressionSyntax)Visit(callSymbol);

if (i == chain.Count - 1)
{
ExpressionSyntax condition = returnType.IsValueType
ExpressionSyntax condition = funcArgumentType.IsValueType
? PrefixUnaryExpression(
SyntaxKind.LogicalNotExpression,
MemberAccessExpression(
Expand All @@ -191,8 +194,8 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression(
var castTo = MaybeNullableType(ProcessSymbol(returnType), returnType.IsValueType);

var conditional = ConditionalExpression(
condition,
CastExpression(castTo, LiteralExpression(SyntaxKind.NullLiteralExpression)),
condition.AppendSpace(),
CastExpression(castTo, LiteralExpression(SyntaxKind.NullLiteralExpression)).PrependSpace().AppendSpace(),
unwrappedExpr.PrependSpace());

lastExpression = conditional;
Expand All @@ -204,12 +207,13 @@ BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpression(
var returnNullStatement = ReturnStatement(LiteralExpression(SyntaxKind.NullLiteralExpression).PrependSpace());

var ifBlock = ((ICollection<StatementSyntax>)[returnNullStatement]).CreateBlock(3);
statements.Add(IfStatement(CheckNull(toCheckForNullExpr), ifBlock));
var ifStatement = IfStatement(Token(SyntaxKind.IfKeyword), Token(SyntaxKind.OpenParenToken).PrependSpace(), CheckNull(toCheckForNullExpr), Token(SyntaxKind.CloseParenToken), ifBlock, null);
statements.Add(ifStatement);

var toCheckForNull = Identifier($"check{i}");
var localType = ProcessSymbol(returnType); // reduced will return generic

var declarator = VariableDeclarator(toCheckForNull.AppendSpace(), null, EqualsValueClause(unwrappedExpr));
var declarator = VariableDeclarator(toCheckForNull.AppendSpace(), null, EqualsValueClause(unwrappedExpr.PrependSpace()));
var declaration = VariableDeclaration(localType.AppendSpace(), SeparatedList([declarator]));
var intermediaryNullCheck = LocalDeclarationStatement(declaration);

Expand Down Expand Up @@ -557,19 +561,31 @@ List<SyntaxTrivia> RemoveFirstEndIf(SyntaxTriviaList list)
newName = GetNewName(methodSymbol);
}

var reducedFromExtensionMethod = methodSymbol.IsExtensionMethod ? methodSymbol.ReducedFrom : null;

// Handle non null conditional access expression eg. arr?.First()
if (@base.Expression is MemberBindingExpressionSyntax
&& reducedFromExtensionMethod is not null
&& replaceInInvocation.Count > 0)
{
return UnwrapExtension(@base, isMemory, reducedFromExtensionMethod, replaceInInvocation.Pop());
}

if (@base.Expression is not MemberAccessExpressionSyntax { } memberAccess)
{
return isMemory ? AppendSpan(@base) : @base;
}

// Handle .ConfigureAwait() and return the part in front of it
if (IsTaskExtension(methodSymbol))
{
return memberAccess.Expression;
}

if (methodSymbol is { IsExtensionMethod: true, ReducedFrom: { } reducedFrom })
// Handle all other extension methods eg. arr.First()
if (reducedFromExtensionMethod is not null)
{
return UnwrapExtension(@base, isMemory, reducedFrom, memberAccess.Expression);
return UnwrapExtension(@base, isMemory, reducedFromExtensionMethod, memberAccess.Expression);
}

if (memberAccess.Name is not SimpleNameSyntax { Identifier.ValueText: { } name })
Expand Down Expand Up @@ -2081,7 +2097,7 @@ public SyntaxList<StatementSyntax> PostProcess(SyntaxList<StatementSyntax> state
}
}

private sealed record ExtensionExprSymbol(InvocationExpressionSyntax InvocationExpression, IMethodSymbol ReducedFrom, ITypeSymbol ReturnType);
private sealed record ExtensionExprSymbol(InvocationExpressionSyntax InvocationExpression, ITypeSymbol ReturnType);

private sealed record RemoveArgumentContext(bool IsNegated = false);
}
6 changes: 3 additions & 3 deletions tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

<PropertyGroup>
<!--Suppress warninigs alphabetically and numerically-->
<NoWarn>$(NoWarn);CA1303;CA1515;CA1812;CA1815;CA1822;CA1852;CA2000;CA5394</NoWarn>
<NoWarn>$(NoWarn);CA1303;CA1508;CA1515;CA1812;CA1815;CA1822;CA1829;CA1852;CA2000;CA5394</NoWarn>
<NoWarn>$(NoWarn);CS0219;CS0162;CS1998;CS8603;CS8619</NoWarn>
<NoWarn>$(NoWarn);IDE0005;IDE0011;IDE0035;IDE0058;IDE0060;IDE0065</NoWarn>
<NoWarn>$(NoWarn);IDE0004;IDE0005;IDE0011;IDE0035;IDE0041;IDE0058;IDE0060;IDE0065</NoWarn>
<NoWarn>$(NoWarn);RS1035</NoWarn>
<NoWarn>$(NoWarn);SA1200;SA1201;SA1400;SA1402;SA1403;SA1404;SA1601</NoWarn>
<NoWarn>$(NoWarn);SA1200;SA1201;SA1204;SA1400;SA1402;SA1403;SA1404;SA1601</NoWarn>
<ImplicitUsings>false</ImplicitUsings>
<TargetFrameworks>net8.0;net6.0</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' == 'Windows_NT'">$(TargetFrameworks);net472</TargetFrameworks>
Expand Down
17 changes: 17 additions & 0 deletions tests/Generator.Tests/ConditionalExtensionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,21 @@ public Task LongChained() => """
?.Where(z => 2 == 0)
?.Where(z => 3 == 0);
""".Verify(sourceType: SourceType.MethodBody);

[Fact]
public Task CheckArrayNullability() => """
int[]? array = null;
var z = array?.Single();
""".Verify(sourceType: SourceType.MethodBody);

#if NET6_0_OR_GREATER
[Fact]
public Task ConditionalToExtension() => """
[CreateSyncVersion]
public static async Task MethodAsync(IEnumerable<int>? integers)
{
var res = integers?.Chunk(2).First();
}
""".Verify();
#endif
}
3 changes: 2 additions & 1 deletion tests/Generator.Tests/NullabilityTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,14 @@ partial class MyClass
[CreateSyncVersion(OmitNullableDirective = true)]
public async Task MethodAsync(int? myInt)
{
_ = myInt?.DoSomething().DoSomething();
_ = myInt?.DoSomething().DoSomething2();
}
}

internal static class Extension
{
public static int DoSomething(this int myInt) => myInt;
public static int DoSomething2(this int myInt) => myInt;
}
""".Verify(sourceType: SourceType.Full);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
//HintName: Test.Class.MethodAsync.g.cs
int[]? array = null;
var z = ((object?)array == null ? (int?)null : global::System.Linq.Enumerable.Single(array));
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.MethodAsync.g.cs
public static void Method(global::System.Collections.Generic.IEnumerable<int>? integers)
{
var res = ((object?)integers == null ? (int[]?)null : global::System.Linq.Enumerable.First(global::System.Linq.Enumerable.Chunk(integers, 2)));
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@
global::System.Reflection.Assembly? a = null;
_ = ((global::System.Func<global::System.Reflection.Assembly?,global::System.Collections.Generic.IEnumerable<global::System.Attribute>?>)((param)=>
{
if((object?)param == null)
if ((object?)param == null)
{
return null;
}
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check0 =global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param);
if((object?)check0 == null)
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check0 = global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param);
if ((object?)check0 == null)
{
return null;
}
global::System.Collections.Generic.List<global::System.Attribute> check1 =global::System.Linq.Enumerable.ToList(check0);
if((object?)check1 == null)
global::System.Collections.Generic.List<global::System.Attribute> check1 = global::System.Linq.Enumerable.ToList(check0);
if ((object?)check1 == null)
{
return null;
}
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check2 =global::System.Linq.Enumerable.Where(check1, z => 1 == 0);
if((object?)check2 == null)
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check2 = global::System.Linq.Enumerable.Where(check1, z => 1 == 0);
if ((object?)check2 == null)
{
return null;
}
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check3 =global::System.Linq.Enumerable.Where(check2, z => 2 == 0);
return (object?)check3 == null?(global::System.Collections.Generic.IEnumerable<global::System.Attribute>?)null: global::System.Linq.Enumerable.Where(check3, z => 3 == 0);
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check3 = global::System.Linq.Enumerable.Where(check2, z => 2 == 0);
return (object?)check3 == null ? (global::System.Collections.Generic.IEnumerable<global::System.Attribute>?)null : global::System.Linq.Enumerable.Where(check3, z => 3 == 0);
}))(a);
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
//HintName: Test.Class.MethodAsync.g.cs
global::System.Reflection.Assembly? a = null;
_ = ((object?)a == null?(global::System.Collections.Generic.IEnumerable<global::System.Attribute>?)null: global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(a));
_ = ((object?)a == null ? (global::System.Collections.Generic.IEnumerable<global::System.Attribute>?)null : global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(a));
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
global::System.Reflection.Assembly? a = null;
_ = ((global::System.Func<global::System.Reflection.Assembly?,global::System.Collections.Generic.List<global::System.Attribute>?>)((param)=>
{
if((object?)param == null)
if ((object?)param == null)
{
return null;
}
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check0 =global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param);
return (object?)check0 == null?(global::System.Collections.Generic.List<global::System.Attribute>?)null: global::System.Linq.Enumerable.ToList(check0);
global::System.Collections.Generic.IEnumerable<global::System.Attribute> check0 = global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param);
return (object?)check0 == null ? (global::System.Collections.Generic.List<global::System.Attribute>?)null : global::System.Linq.Enumerable.ToList(check0);
}))(a);
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
//HintName: Test.Class.MethodAsync.g.cs
_ = ((global::System.Func<global::System.Reflection.Assembly?,global::System.Collections.Generic.IEnumerable<global::System.Attribute>?>)((param)=>(object?)param == null?(global::System.Collections.Generic.IEnumerable<global::System.Attribute>?)null: global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param)))(GetType().Assembly);
_ = ((global::System.Func<global::System.Reflection.Assembly?,global::System.Collections.Generic.IEnumerable<global::System.Attribute>?>)((param)=>(object?)param == null ? (global::System.Collections.Generic.IEnumerable<global::System.Attribute>?)null : global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes(param)))(GetType().Assembly);
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ partial class Class
{
public void Method()
{
_ = ((global::System.Func<global::Tests.Bar?,global::Tests.Bar?>)((param)=>(object?)param == null?(global::Tests.Bar?)null: global::Tests.BarExtension.DoSomething(param)))(global::Tests.Bar.Create());
_ = ((global::System.Func<global::Tests.Bar?,global::Tests.Bar?>)((param)=>(object?)param == null ? (global::Tests.Bar?)null : global::Tests.BarExtension.DoSomething(param)))(global::Tests.Bar.Create());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ partial class MyClass
{
public void Method(global::System.IO.Stream stream)
{
_ = ((object)stream == null?(global::System.IO.Stream)null: global::Test.Extension.DoSomething(stream));
_ = ((object)stream == null ? (global::System.IO.Stream)null : global::Test.Extension.DoSomething(stream));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ partial class MyClass
{
public void Method(int? myInt)
{
_ = (!myInt.HasValue?(int?)null: global::Test.Extension.DoSomething(myInt.Value));
_ = (!myInt.HasValue ? (int?)null : global::Test.Extension.DoSomething2(global::Test.Extension.DoSomething(myInt.Value)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ public void Method(global::System.IO.Stream stream)
{
_ = ((global::System.Func<global::System.IO.Stream,global::System.IO.Stream>)((param)=>
{
if((object)param == null)
if ((object)param == null)
{
return null;
}
global::System.IO.Stream check0 =global::Test.Extension.DoSomething(param);
return (object)check0 == null?(global::System.IO.Stream)null: global::Test.Extension.DoSomething(check0);
global::System.IO.Stream check0 = global::Test.Extension.DoSomething(param);
return (object)check0 == null ? (global::System.IO.Stream)null : global::Test.Extension.DoSomething(check0);
}))(stream);
}
}

0 comments on commit a1c0e95

Please sign in to comment.