Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement FAR for GetAwaiter methods #28230

Merged
merged 2 commits into from
Jul 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,83 @@ Namespace Microsoft.CodeAnalysis.Editor.UnitTests.FindReferences
Await TestAPIAndFeature(input)
End Function

<WpfFact, Trait(Traits.Feature, Traits.Features.FindReferences)>
<WorkItem(18963, "https://github.com/dotnet/roslyn/issues/18963")>
Public Async Function FindReferences_GetAwaiter() As Task
Dim input =
<Workspace>
<Project Language="C#" CommonReferences="true">
<Document><![CDATA[
using System.Threading.Tasks;
using System.Runtime.CompilerServices;
public class C
{
public TaskAwaiter<bool> {|Definition:Get$$Awaiter|}() => Task.FromResult(true).GetAwaiter();

static async void M(C c)
{
[|await|] c;
[|await|] c;
}
}
]]></Document>
</Project>
</Workspace>
Await TestAPIAndFeature(input)
End Function

<WpfFact, Trait(Traits.Feature, Traits.Features.FindReferences)>
<WorkItem(18963, "https://github.com/dotnet/roslyn/issues/18963")>
Public Async Function FindReferences_GetAwaiter_VB() As Task
Dim input =
<Workspace>
<Project Language="Visual Basic" CommonReferences="true">
<Document><![CDATA[
Imports System.Threading.Tasks
Imports System.Runtime.CompilerServices
Public Class C
Public Function {|Definition:Get$$Awaiter|}() As TaskAwaiter(Of Boolean)
End Function

Shared Async Sub M(c As C)
[|Await|] c
End Sub
End Class
]]></Document>
</Project>
</Workspace>
Await TestAPIAndFeature(input)
End Function

<WpfFact, Trait(Traits.Feature, Traits.Features.FindReferences)>
<WorkItem(18963, "https://github.com/dotnet/roslyn/issues/18963")>
Public Async Function FindReferences_GetAwaiterInAnotherDocument() As Task
Dim input =
<Workspace>
<Project Language="C#" CommonReferences="true">
<Document><![CDATA[
using System.Threading.Tasks;
using System.Runtime.CompilerServices;
public class C
{
public TaskAwaiter<bool> {|Definition:Get$$Awaiter|}() => Task.FromResult(true).GetAwaiter();
}
]]></Document>
<Document><![CDATA[
class D
{
static async void M(C c)
{
[|await|] c;
[|await|] c;
}
}
]]></Document>
</Project>
</Workspace>
Await TestAPIAndFeature(input)
End Function

<WpfFact, Trait(Traits.Feature, Traits.Features.FindReferences)>
<WorkItem(18963, "https://github.com/dotnet/roslyn/issues/18963")>
Public Async Function FindReferences_Deconstruction() As Task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ class C
End Using
End Sub

<Fact, Trait(Traits.Feature, Traits.Features.Rename)>
Public Sub RenameGetAwaiterCausesConflict()
Using result = RenameEngineResult.Create(_outputHelper,
<Workspace>
<Project Language="C#" CommonReferences="true">
<Document><![CDATA[
using System.Threading.Tasks;
using System.Runtime.CompilerServices;
public class C
{
public TaskAwaiter<bool> [|Get$$Awaiter|]() => Task.FromResult(true).GetAwaiter();

static async void M(C c)
{
{|awaitconflict:await|} c;
}
}
]]></Document>
</Project>
</Workspace>, renameTo:="GetAwaiter2")

result.AssertLabeledSpansAre("awaitconflict", type:=RelatedLocationType.UnresolvedConflict)
End Using
End Sub

<WorkItem(528966, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/528966")>
<Fact, Trait(Traits.Feature, Traits.Features.Rename)>
Public Sub RenameMoveNextInVBCausesConflictInForEach()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ public ForEachSymbols GetForEachSymbols(SemanticModel semanticModel, SyntaxNode
}
}

public IMethodSymbol GetGetAwaiterMethod(SemanticModel semanticModel, SyntaxNode node)
{
if (node is AwaitExpressionSyntax awaitExpression)
{
var info = semanticModel.GetAwaitExpressionInfo(awaitExpression);
return info.GetAwaiterMethod;
}

return null;
}

public ImmutableArray<IMethodSymbol> GetDeconstructionAssignmentMethods(SemanticModel semanticModel, SyntaxNode node)
{
if (node is AssignmentExpressionSyntax assignment && assignment.IsDeconstruction())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,9 @@ public bool IsExpressionOfInvocationExpression(SyntaxNode node)
return node != null && (node.Parent as InvocationExpressionSyntax)?.Expression == node;
}

public bool IsAwaitExpression(SyntaxNode node)
=> node.IsKind(SyntaxKind.AwaitExpression);

public bool IsExpressionOfAwaitExpression(SyntaxNode node)
{
return node != null && (node.Parent as AwaitExpressionSyntax)?.Expression == node;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,7 @@ public async Task<ImmutableArray<Location>> ComputeImplicitReferenceConflictsAsy
(renameSymbol.Kind == SymbolKind.Method &&
(string.Compare(renameSymbol.Name, WellKnownMemberNames.MoveNextMethodName, StringComparison.OrdinalIgnoreCase) == 0 ||
string.Compare(renameSymbol.Name, WellKnownMemberNames.GetEnumeratorMethodName, StringComparison.OrdinalIgnoreCase) == 0 ||
string.Compare(renameSymbol.Name, WellKnownMemberNames.GetAwaiter, StringComparison.OrdinalIgnoreCase) == 0 ||
string.Compare(renameSymbol.Name, WellKnownMemberNames.DeconstructMethodName, StringComparison.OrdinalIgnoreCase) == 0));

// TODO: handle Dispose for using statement and Add methods for collection initializers.
Expand All @@ -1013,6 +1014,8 @@ public async Task<ImmutableArray<Location>> ComputeImplicitReferenceConflictsAsy
{
case SyntaxKind.ForEachKeyword:
return ImmutableArray.Create(((CommonForEachStatementSyntax)token.Parent).Expression.GetLocation());
case SyntaxKind.AwaitKeyword:
return ImmutableArray.Create(token.GetLocation());
}

if (token.Parent.IsInDeconstructionLeft(out var deconstructionLeft))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,33 +395,42 @@ private static async Task<ImmutableArray<ReferenceLocation>> FindReferencesThrou
return allAliasReferences.ToImmutableAndFree();
}

protected Task<ImmutableArray<Document>> FindDocumentsWithForEachStatementsAsync(Project project, IImmutableSet<Document> documents, CancellationToken cancellationToken)
private Task<ImmutableArray<Document>> FindDocumentsWithPredicateAsync(Project project, IImmutableSet<Document> documents, Func<SyntaxTreeIndex, bool> predicate, CancellationToken cancellationToken)
{
return FindDocumentsAsync(project, documents, async (d, c) =>
{
var info = await SyntaxTreeIndex.GetIndexAsync(d, c).ConfigureAwait(false);
return info.ContainsForEachStatement;
return predicate(info);
}, cancellationToken);
}

protected Task<ImmutableArray<Document>> FindDocumentsWithForEachStatementsAsync(Project project, IImmutableSet<Document> documents, CancellationToken cancellationToken)
=> FindDocumentsWithPredicateAsync(project, documents, predicate: sti => sti.ContainsForEachStatement, cancellationToken);

protected Task<ImmutableArray<Document>> FindDocumentsWithDeconstructionAsync(Project project, IImmutableSet<Document> documents, CancellationToken cancellationToken)
{
return FindDocumentsAsync(project, documents, async (d, c) =>
{
var info = await SyntaxTreeIndex.GetIndexAsync(d, c).ConfigureAwait(false);
return info.ContainsDeconstruction;
}, cancellationToken);
}
=> FindDocumentsWithPredicateAsync(project, documents, predicate: sti => sti.ContainsDeconstruction, cancellationToken);

protected Task<ImmutableArray<Document>> FindDocumentsWithAwaitExpressionAsync(Project project, IImmutableSet<Document> documents, CancellationToken cancellationToken)
=> FindDocumentsWithPredicateAsync(project, documents, predicate: sti => sti.ContainsAwait, cancellationToken);

protected async Task<ImmutableArray<ReferenceLocation>> FindReferencesInForEachStatementsAsync(
/// <summary>
/// If the `node` implicitly matches the `symbol`, then it will be added to `locations`.
/// </summary>
private delegate void CollectMatchingReferences(ISymbol symbol, SyntaxNode node,
ISyntaxFactsService syntaxFacts, ISemanticFactsService semanticFacts, ArrayBuilder<ReferenceLocation> locations);

private async Task<ImmutableArray<ReferenceLocation>> FindReferencesInDocumentAsync(
ISymbol symbol,
Document document,
SemanticModel semanticModel,
Func<SyntaxTreeIndex, bool> isRelevantDocument,
CollectMatchingReferences collectMatchingReferences,
CancellationToken cancellationToken)
{
var syntaxTreeInfo = await SyntaxTreeIndex.GetIndexAsync(document, cancellationToken).ConfigureAwait(false);
if (syntaxTreeInfo.ContainsForEachStatement)
if (isRelevantDocument(syntaxTreeInfo))
{
var syntaxFacts = document.GetLanguageService<ISyntaxFactsService>();
var semanticFacts = document.GetLanguageService<ISemanticFactsService>();
var syntaxRoot = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

Expand All @@ -432,50 +441,57 @@ protected async Task<ImmutableArray<ReferenceLocation>> FindReferencesInForEachS
foreach (var node in syntaxRoot.DescendantNodesAndSelf())
{
cancellationToken.ThrowIfCancellationRequested();
var info = semanticFacts.GetForEachSymbols(semanticModel, node);

if (Matches(info.GetEnumeratorMethod, originalUnreducedSymbolDefinition) ||
Matches(info.MoveNextMethod, originalUnreducedSymbolDefinition) ||
Matches(info.CurrentProperty, originalUnreducedSymbolDefinition) ||
Matches(info.DisposeMethod, originalUnreducedSymbolDefinition))
{
var location = node.GetFirstToken().GetLocation();
locations.Add(new ReferenceLocation(
document, alias: null, location: location, isImplicit: true, isWrittenTo: false, candidateReason: CandidateReason.None));
}
collectMatchingReferences(originalUnreducedSymbolDefinition, node, syntaxFacts, semanticFacts, locations);
}

return locations.ToImmutableAndFree();
}
else
{
return ImmutableArray<ReferenceLocation>.Empty;
}

return ImmutableArray<ReferenceLocation>.Empty;
}

protected async Task<ImmutableArray<ReferenceLocation>> FindReferencesInDeconstructionAsync(
protected Task<ImmutableArray<ReferenceLocation>> FindReferencesInForEachStatementsAsync(
ISymbol symbol,
Document document,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
var syntaxTreeInfo = await SyntaxTreeIndex.GetIndexAsync(document, cancellationToken).ConfigureAwait(false);
if (!syntaxTreeInfo.ContainsDeconstruction)
return FindReferencesInDocumentAsync(symbol, document, semanticModel, isRelevantDocument, collectMatchingReferences, cancellationToken);

bool isRelevantDocument(SyntaxTreeIndex syntaxTreeInfo)
=> syntaxTreeInfo.ContainsForEachStatement;

void collectMatchingReferences(ISymbol originalUnreducedSymbolDefinition, SyntaxNode node,
ISyntaxFactsService syntaxFacts, ISemanticFactsService semanticFacts, ArrayBuilder<ReferenceLocation> locations)
{
return ImmutableArray<ReferenceLocation>.Empty;
}
var info = semanticFacts.GetForEachSymbols(semanticModel, node);

var syntaxFacts = document.GetLanguageService<ISyntaxFactsService>();
var semanticFacts = document.GetLanguageService<ISemanticFactsService>();
var syntaxRoot = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
if (Matches(info.GetEnumeratorMethod, originalUnreducedSymbolDefinition) ||
Matches(info.MoveNextMethod, originalUnreducedSymbolDefinition) ||
Matches(info.CurrentProperty, originalUnreducedSymbolDefinition) ||
Matches(info.DisposeMethod, originalUnreducedSymbolDefinition))
{
var location = node.GetFirstToken().GetLocation();
locations.Add(new ReferenceLocation(
document, alias: null, location: location, isImplicit: true, isWrittenTo: false, candidateReason: CandidateReason.None));
}
}
}

var locations = ArrayBuilder<ReferenceLocation>.GetInstance();
protected Task<ImmutableArray<ReferenceLocation>> FindReferencesInDeconstructionAsync(
ISymbol symbol,
Document document,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
return FindReferencesInDocumentAsync(symbol, document, semanticModel, isRelevantDocument, collectMatchingReferences, cancellationToken);

var originalUnreducedSymbolDefinition = symbol.GetOriginalUnreducedDefinition();
bool isRelevantDocument(SyntaxTreeIndex syntaxTreeInfo)
=> syntaxTreeInfo.ContainsDeconstruction;

foreach (var node in syntaxRoot.DescendantNodesAndSelf())
void collectMatchingReferences(ISymbol originalUnreducedSymbolDefinition, SyntaxNode node,
ISyntaxFactsService syntaxFacts, ISemanticFactsService semanticFacts, ArrayBuilder<ReferenceLocation> locations)
{
cancellationToken.ThrowIfCancellationRequested();
var deconstructMethods = semanticFacts.GetDeconstructionAssignmentMethods(semanticModel, node);
if (deconstructMethods.IsEmpty)
{
Expand All @@ -490,8 +506,31 @@ protected async Task<ImmutableArray<ReferenceLocation>> FindReferencesInDeconstr
document, alias: null, location, isImplicit: true, isWrittenTo: false, CandidateReason.None));
}
}
}

return locations.ToImmutableAndFree();
protected Task<ImmutableArray<ReferenceLocation>> FindReferencesInAwaitExpressionAsync(
ISymbol symbol,
Document document,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
return FindReferencesInDocumentAsync(symbol, document, semanticModel, isRelevantDocument, collectMatchingReferences, cancellationToken);

bool isRelevantDocument(SyntaxTreeIndex syntaxTreeInfo)
=> syntaxTreeInfo.ContainsAwait;

void collectMatchingReferences(ISymbol originalUnreducedSymbolDefinition, SyntaxNode node,
ISyntaxFactsService syntaxFacts, ISemanticFactsService semanticFacts, ArrayBuilder<ReferenceLocation> locations)
{
var awaitExpressionMethod = semanticFacts.GetGetAwaiterMethod(semanticModel, node);

if (Matches(awaitExpressionMethod, originalUnreducedSymbolDefinition))
{
var location = node.GetFirstToken().GetLocation();
locations.Add(new ReferenceLocation(
document, alias: null, location, isImplicit: true, isWrittenTo: false, CandidateReason.None));
}
}
}

private static bool Matches(ISymbol symbol1, ISymbol notNulloriginalUnreducedSymbol2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ protected override async Task<ImmutableArray<Document>> DetermineDocumentsToSear
? await FindDocumentsWithDeconstructionAsync(project, documents, cancellationToken).ConfigureAwait(false)
: ImmutableArray<Document>.Empty;

return ordinaryDocuments.Concat(forEachDocuments).Concat(deconstructDocuments);
var awaitExpressionDocuments = IsGetAwaiterMethod(methodSymbol)
? await FindDocumentsWithAwaitExpressionAsync(project, documents, cancellationToken).ConfigureAwait(false)
: ImmutableArray<Document>.Empty;

return ordinaryDocuments.Concat(forEachDocuments).Concat(deconstructDocuments).Concat(awaitExpressionDocuments);
}

private bool IsForEachMethod(IMethodSymbol methodSymbol)
Expand All @@ -110,6 +114,9 @@ private bool IsForEachMethod(IMethodSymbol methodSymbol)
private bool IsDeconstructMethod(IMethodSymbol methodSymbol)
=> methodSymbol.Name == WellKnownMemberNames.DeconstructMethodName;

private bool IsGetAwaiterMethod(IMethodSymbol methodSymbol)
=> methodSymbol.Name == WellKnownMemberNames.GetAwaiter;

protected override async Task<ImmutableArray<ReferenceLocation>> FindReferencesInDocumentAsync(
IMethodSymbol symbol,
Document document,
Expand All @@ -135,6 +142,12 @@ protected override async Task<ImmutableArray<ReferenceLocation>> FindReferencesI
nameMatches = nameMatches.Concat(deconstructMatches);
}

if (IsGetAwaiterMethod(symbol))
{
var getAwaiterMatches = await FindReferencesInAwaitExpressionAsync(symbol, document, semanticModel, cancellationToken).ConfigureAwait(false);
nameMatches = nameMatches.Concat(getAwaiterMatches);
}

return nameMatches;
}
}
Expand Down
Loading