diff --git a/src/Analyzers/CSharp/CodeFixes/MakeMethodAsynchronous/CSharpMakeMethodAsynchronousCodeFixProvider.cs b/src/Analyzers/CSharp/CodeFixes/MakeMethodAsynchronous/CSharpMakeMethodAsynchronousCodeFixProvider.cs index ebfd5c0ee4ada..e5f700b62bc21 100644 --- a/src/Analyzers/CSharp/CodeFixes/MakeMethodAsynchronous/CSharpMakeMethodAsynchronousCodeFixProvider.cs +++ b/src/Analyzers/CSharp/CodeFixes/MakeMethodAsynchronous/CSharpMakeMethodAsynchronousCodeFixProvider.cs @@ -66,7 +66,7 @@ protected override bool IsAsyncSupportingFunctionSyntax(SyntaxNode node) protected override bool IsAsyncReturnType(ITypeSymbol type, KnownTypes knownTypes) => IsIAsyncEnumerableOrEnumerator(type, knownTypes) || - IsTaskLike(type, knownTypes); + knownTypes.IsTaskLike(type); protected override SyntaxNode AddAsyncTokenAndFixReturnType( bool keepVoid, @@ -129,21 +129,21 @@ private static TypeSyntax FixMethodReturnType( var returnType = methodSymbol.ReturnType; if (IsIEnumerable(returnType, knownTypes) && IsIterator(methodSymbol, cancellationToken)) { - newReturnType = knownTypes.IAsyncEnumerableOfTTypeOpt is null + newReturnType = knownTypes.IAsyncEnumerableOfTType is null ? MakeGenericType(nameof(IAsyncEnumerable), methodSymbol.ReturnType) - : knownTypes.IAsyncEnumerableOfTTypeOpt.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax(); + : knownTypes.IAsyncEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax(); } else if (IsIEnumerator(returnType, knownTypes) && IsIterator(methodSymbol, cancellationToken)) { - newReturnType = knownTypes.IAsyncEnumeratorOfTTypeOpt is null + newReturnType = knownTypes.IAsyncEnumeratorOfTType is null ? MakeGenericType(nameof(IAsyncEnumerator), methodSymbol.ReturnType) - : knownTypes.IAsyncEnumeratorOfTTypeOpt.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax(); + : knownTypes.IAsyncEnumeratorOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax(); } else if (IsIAsyncEnumerableOrEnumerator(returnType, knownTypes)) { // Leave the return type alone } - else if (!IsTaskLike(returnType, knownTypes)) + else if (!knownTypes.IsTaskLike(returnType)) { // If it's not already Task-like, then wrap the existing return type // in Task<>. @@ -167,8 +167,8 @@ private static bool IsIterator(IMethodSymbol method, CancellationToken cancellat => method.Locations.Any(static (loc, cancellationToken) => loc.FindNode(cancellationToken).ContainsYield(), cancellationToken); private static bool IsIAsyncEnumerableOrEnumerator(ITypeSymbol returnType, KnownTypes knownTypes) - => returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTTypeOpt) || - returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTTypeOpt); + => returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTType) || + returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTType); private static bool IsIEnumerable(ITypeSymbol returnType, KnownTypes knownTypes) => returnType.OriginalDefinition.Equals(knownTypes.IEnumerableOfTType); diff --git a/src/Analyzers/CSharp/CodeFixes/MakeMethodSynchronous/CSharpMakeMethodSynchronousCodeFixProvider.cs b/src/Analyzers/CSharp/CodeFixes/MakeMethodSynchronous/CSharpMakeMethodSynchronousCodeFixProvider.cs index eaf4453c6806d..9bbe3ce2585cc 100644 --- a/src/Analyzers/CSharp/CodeFixes/MakeMethodSynchronous/CSharpMakeMethodSynchronousCodeFixProvider.cs +++ b/src/Analyzers/CSharp/CodeFixes/MakeMethodSynchronous/CSharpMakeMethodSynchronousCodeFixProvider.cs @@ -70,13 +70,13 @@ private static TypeSyntax FixMethodReturnType(IMethodSymbol methodSymbol, TypeSy // If the return type is Task, then make the new return type "T". newReturnType = returnType.GetTypeArguments()[0].GenerateTypeSyntax().WithTriviaFrom(returnTypeSyntax); } - else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTTypeOpt) && + else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTType) && knownTypes.IEnumerableOfTType != null) { // If the return type is IAsyncEnumerable, then make the new return type IEnumerable. newReturnType = knownTypes.IEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax(); } - else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTTypeOpt) && + else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTType) && knownTypes.IEnumeratorOfTType != null) { // If the return type is IAsyncEnumerator, then make the new return type IEnumerator. diff --git a/src/Analyzers/Core/CodeFixes/MakeMethodAsynchronous/AbstractMakeMethodAsynchronousCodeFixProvider.cs b/src/Analyzers/Core/CodeFixes/MakeMethodAsynchronous/AbstractMakeMethodAsynchronousCodeFixProvider.cs index d1eade8a338a0..d1ebc096c0b19 100644 --- a/src/Analyzers/Core/CodeFixes/MakeMethodAsynchronous/AbstractMakeMethodAsynchronousCodeFixProvider.cs +++ b/src/Analyzers/Core/CodeFixes/MakeMethodAsynchronous/AbstractMakeMethodAsynchronousCodeFixProvider.cs @@ -197,36 +197,5 @@ private async Task AddAsyncTokenAsync( var newDocument = document.WithSyntaxRoot(newRoot); return newDocument.Project.Solution; } - - protected static bool IsTaskLike(ITypeSymbol returnType, KnownTypes knownTypes) - { - if (returnType.Equals(knownTypes.TaskType)) - { - return true; - } - - if (returnType.Equals(knownTypes.ValueTaskType)) - { - return true; - } - - if (returnType.OriginalDefinition.Equals(knownTypes.TaskOfTType)) - { - return true; - } - - if (returnType.OriginalDefinition.Equals(knownTypes.ValueTaskOfTTypeOpt)) - { - return true; - } - - if (returnType.IsErrorType()) - { - return returnType.Name.Equals("Task") || - returnType.Name.Equals("ValueTask"); - } - - return false; - } } } diff --git a/src/Analyzers/Core/CodeFixes/RemoveAsyncModifier/AbstractRemoveAsyncModifierCodeFixProvider.cs b/src/Analyzers/Core/CodeFixes/RemoveAsyncModifier/AbstractRemoveAsyncModifierCodeFixProvider.cs index 2838d373051ce..5d656daf524dd 100644 --- a/src/Analyzers/Core/CodeFixes/RemoveAsyncModifier/AbstractRemoveAsyncModifierCodeFixProvider.cs +++ b/src/Analyzers/Core/CodeFixes/RemoveAsyncModifier/AbstractRemoveAsyncModifierCodeFixProvider.cs @@ -108,7 +108,7 @@ protected sealed override async Task FixAllAsync( private static bool ShouldOfferFix(ITypeSymbol returnType, KnownTypes knownTypes) => IsTaskType(returnType, knownTypes) || returnType.OriginalDefinition.Equals(knownTypes.TaskOfTType) - || returnType.OriginalDefinition.Equals(knownTypes.ValueTaskOfTTypeOpt); + || returnType.OriginalDefinition.Equals(knownTypes.ValueTaskOfTType); private static bool IsTaskType(ITypeSymbol returnType, KnownTypes knownTypes) => returnType.OriginalDefinition.Equals(knownTypes.TaskType) diff --git a/src/Analyzers/VisualBasic/CodeFixes/MakeMethodAsynchronous/VisualBasicMakeMethodAsynchronousCodeFixProvider.vb b/src/Analyzers/VisualBasic/CodeFixes/MakeMethodAsynchronous/VisualBasicMakeMethodAsynchronousCodeFixProvider.vb index ddf0664d0b7a9..9150d855e7357 100644 --- a/src/Analyzers/VisualBasic/CodeFixes/MakeMethodAsynchronous/VisualBasicMakeMethodAsynchronousCodeFixProvider.vb +++ b/src/Analyzers/VisualBasic/CodeFixes/MakeMethodAsynchronous/VisualBasicMakeMethodAsynchronousCodeFixProvider.vb @@ -54,7 +54,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.MakeMethodAsynchronous End Function Protected Overrides Function IsAsyncReturnType(type As ITypeSymbol, knownTypes As KnownTypes) As Boolean - Return IsTaskLike(type, knownTypes) + Return knownTypes.IsTaskLike(type) End Function Protected Overrides Function AddAsyncTokenAndFixReturnType( @@ -84,7 +84,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.MakeMethodAsynchronous Dim functionStatement = node.SubOrFunctionStatement Dim newFunctionStatement = AddAsyncKeyword(functionStatement) - If Not IsTaskLike(methodSymbol.ReturnType, knownTypes) Then + If Not knownTypes.IsTaskLike(methodSymbol.ReturnType) Then ' if the current return type is not already task-list, then wrap it in Task(of ...) Dim returnType = knownTypes.TaskOfTType.Construct(methodSymbol.ReturnType).GenerateTypeSyntax().WithAdditionalAnnotations(Simplifier.AddImportsAnnotation) newFunctionStatement = newFunctionStatement.WithAsClause( diff --git a/src/Features/CSharpTest/GenerateMethod/GenerateMethodTests.cs b/src/Features/CSharpTest/GenerateMethod/GenerateMethodTests.cs index ac68f1d017ff1..d2db75fde8965 100644 --- a/src/Features/CSharpTest/GenerateMethod/GenerateMethodTests.cs +++ b/src/Features/CSharpTest/GenerateMethod/GenerateMethodTests.cs @@ -8036,7 +8036,9 @@ private static int i() """); } - [Fact, WorkItem("https://github.com/dotnet/roslyn/issues/643")] + [Fact] + [WorkItem("https://github.com/dotnet/roslyn/issues/643")] + [WorkItem("https://github.com/dotnet/roslyn/issues/14467")] public async Task TestGenerateMethodWithConfigureAwaitFalse() { await TestInRegularAndScriptAsync( @@ -8067,7 +8069,7 @@ static void Main(string[] args) bool x = await Goo().ConfigureAwait(false); } - private static Task Goo() + private static async Task Goo() { throw new NotImplementedException(); } @@ -8075,7 +8077,9 @@ private static Task Goo() """); } - [Fact, WorkItem("https://github.com/dotnet/roslyn/issues/643")] + [Fact] + [WorkItem("https://github.com/dotnet/roslyn/issues/643")] + [WorkItem("https://github.com/dotnet/roslyn/issues/14467")] public async Task TestGenerateMethodWithMethodChaining() { await TestInRegularAndScriptAsync( @@ -8106,7 +8110,7 @@ static void Main(string[] args) bool x = await Goo().ConfigureAwait(false); } - private static Task Goo() + private static async Task Goo() { throw new NotImplementedException(); } @@ -8114,7 +8118,9 @@ private static Task Goo() """); } - [Fact, WorkItem("https://github.com/dotnet/roslyn/issues/643")] + [Fact] + [WorkItem("https://github.com/dotnet/roslyn/issues/643")] + [WorkItem("https://github.com/dotnet/roslyn/issues/14467")] public async Task TestGenerateMethodWithMethodChaining2() { await TestInRegularAndScriptAsync( @@ -8149,7 +8155,7 @@ static async void T() }); } - private static Task M() + private static async Task M() { throw new NotImplementedException(); } diff --git a/src/Features/Core/Portable/GenerateMember/GenerateParameterizedMember/AbstractGenerateParameterizedMemberService.SignatureInfo.cs b/src/Features/Core/Portable/GenerateMember/GenerateParameterizedMember/AbstractGenerateParameterizedMemberService.SignatureInfo.cs index f7fff4f3ed6fd..539444c6bc870 100644 --- a/src/Features/Core/Portable/GenerateMember/GenerateParameterizedMember/AbstractGenerateParameterizedMemberService.SignatureInfo.cs +++ b/src/Features/Core/Portable/GenerateMember/GenerateParameterizedMember/AbstractGenerateParameterizedMemberService.SignatureInfo.cs @@ -20,257 +20,259 @@ using Microsoft.CodeAnalysis.Utilities; using Roslyn.Utilities; -namespace Microsoft.CodeAnalysis.GenerateMember.GenerateParameterizedMember +namespace Microsoft.CodeAnalysis.GenerateMember.GenerateParameterizedMember; + +internal abstract partial class AbstractGenerateParameterizedMemberService { - internal abstract partial class AbstractGenerateParameterizedMemberService + internal abstract class SignatureInfo( + SemanticDocument document, + State state) { - internal abstract class SignatureInfo( - SemanticDocument document, - State state) + protected readonly SemanticDocument Document = document; + protected readonly State State = state; + private ImmutableArray _typeParameters; + private IDictionary _typeArgumentToTypeParameterMap; + + public ImmutableArray DetermineTypeParameters(CancellationToken cancellationToken) { - protected readonly SemanticDocument Document = document; - protected readonly State State = state; - private ImmutableArray _typeParameters; - private IDictionary _typeArgumentToTypeParameterMap; + return _typeParameters.IsDefault + ? (_typeParameters = DetermineTypeParametersWorker(cancellationToken)) + : _typeParameters; + } + + protected abstract ImmutableArray DetermineTypeParametersWorker(CancellationToken cancellationToken); + protected abstract RefKind DetermineRefKind(CancellationToken cancellationToken); - public ImmutableArray DetermineTypeParameters(CancellationToken cancellationToken) + public ValueTask DetermineReturnTypeAsync(CancellationToken cancellationToken) + { + var type = DetermineReturnTypeWorker(cancellationToken); + if (State.IsInConditionalAccessExpression) { - return _typeParameters.IsDefault - ? (_typeParameters = DetermineTypeParametersWorker(cancellationToken)) - : _typeParameters; + type = type.RemoveNullableIfPresent(); } - protected abstract ImmutableArray DetermineTypeParametersWorker(CancellationToken cancellationToken); - protected abstract RefKind DetermineRefKind(CancellationToken cancellationToken); + return FixTypeAsync(type, cancellationToken); + } - public ValueTask DetermineReturnTypeAsync(CancellationToken cancellationToken) - { - var type = DetermineReturnTypeWorker(cancellationToken); - if (State.IsInConditionalAccessExpression) - { - type = type.RemoveNullableIfPresent(); - } + protected abstract ImmutableArray DetermineTypeArguments(CancellationToken cancellationToken); + protected abstract ITypeSymbol DetermineReturnTypeWorker(CancellationToken cancellationToken); + protected abstract ImmutableArray DetermineParameterModifiers(CancellationToken cancellationToken); + protected abstract ImmutableArray DetermineParameterTypes(CancellationToken cancellationToken); + protected abstract ImmutableArray DetermineParameterOptionality(CancellationToken cancellationToken); + protected abstract ImmutableArray DetermineParameterNames(CancellationToken cancellationToken); + + internal async ValueTask GeneratePropertyAsync( + SyntaxGenerator factory, + bool isAbstract, bool includeSetter, + CancellationToken cancellationToken) + { + var accessibility = DetermineAccessibility(isAbstract); + var getMethod = CodeGenerationSymbolFactory.CreateAccessorSymbol( + attributes: default, + accessibility: accessibility, + statements: GenerateStatements(factory, isAbstract)); + + var setMethod = includeSetter ? getMethod : null; + + return CodeGenerationSymbolFactory.CreatePropertySymbol( + attributes: default, + accessibility: accessibility, + modifiers: new DeclarationModifiers(isStatic: State.IsStatic, isAbstract: isAbstract), + type: await DetermineReturnTypeAsync(cancellationToken).ConfigureAwait(false), + refKind: DetermineRefKind(cancellationToken), + explicitInterfaceImplementations: default, + name: State.IdentifierToken.ValueText, + parameters: await DetermineParametersAsync(cancellationToken).ConfigureAwait(false), + getMethod: getMethod, + setMethod: setMethod); + } - return FixTypeAsync(type, cancellationToken); + public async ValueTask GenerateMethodAsync( + SyntaxGenerator factory, + bool isAbstract, + CancellationToken cancellationToken) + { + var parameters = await DetermineParametersAsync(cancellationToken).ConfigureAwait(false); + var returnType = await DetermineReturnTypeAsync(cancellationToken).ConfigureAwait(false); + var isUnsafe = false; + if (!State.IsContainedInUnsafeType) + { + isUnsafe = returnType.RequiresUnsafeModifier() || parameters.Any(static p => p.Type.RequiresUnsafeModifier()); } - protected abstract ImmutableArray DetermineTypeArguments(CancellationToken cancellationToken); - protected abstract ITypeSymbol DetermineReturnTypeWorker(CancellationToken cancellationToken); - protected abstract ImmutableArray DetermineParameterModifiers(CancellationToken cancellationToken); - protected abstract ImmutableArray DetermineParameterTypes(CancellationToken cancellationToken); - protected abstract ImmutableArray DetermineParameterOptionality(CancellationToken cancellationToken); - protected abstract ImmutableArray DetermineParameterNames(CancellationToken cancellationToken); - - internal async ValueTask GeneratePropertyAsync( - SyntaxGenerator factory, - bool isAbstract, bool includeSetter, - CancellationToken cancellationToken) - { - var accessibility = DetermineAccessibility(isAbstract); - var getMethod = CodeGenerationSymbolFactory.CreateAccessorSymbol( - attributes: default, - accessibility: accessibility, - statements: GenerateStatements(factory, isAbstract)); + var knownTypes = new KnownTypes(Document.SemanticModel.Compilation); + + var method = CodeGenerationSymbolFactory.CreateMethodSymbol( + attributes: default, + accessibility: DetermineAccessibility(isAbstract), + modifiers: new DeclarationModifiers( + isStatic: State.IsStatic, isAbstract: isAbstract, isUnsafe: isUnsafe, isAsync: knownTypes.IsTaskLike(returnType)), + returnType: returnType, + refKind: DetermineRefKind(cancellationToken), + explicitInterfaceImplementations: default, + name: State.IdentifierToken.ValueText, + typeParameters: DetermineTypeParameters(cancellationToken), + parameters: parameters, + statements: GenerateStatements(factory, isAbstract), + handlesExpressions: default, + returnTypeAttributes: default, + methodKind: State.MethodKind); + + // Ensure no conflicts between type parameter names and parameter names. + var languageServiceProvider = Document.Project.Solution.Services.GetLanguageServices(State.TypeToGenerateIn.Language); + var syntaxFacts = languageServiceProvider.GetService(); + + var equalityComparer = syntaxFacts.StringComparer; + var reservedParameterNames = DetermineParameterNames(cancellationToken) + .Select(p => p.BestNameForParameter) + .ToSet(equalityComparer); + + var newTypeParameterNames = NameGenerator.EnsureUniqueness( + method.TypeParameters.SelectAsArray(t => t.Name), + n => !reservedParameterNames.Contains(n)); + + return method.RenameTypeParameters(newTypeParameterNames); + } - var setMethod = includeSetter ? getMethod : null; + private async ValueTask FixTypeAsync( + ITypeSymbol typeSymbol, + CancellationToken cancellationToken) + { + // A type can't refer to a type parameter that isn't available in the type we're + // eventually generating into. + var availableMethodTypeParameters = DetermineTypeParameters(cancellationToken); + var availableTypeParameters = State.TypeToGenerateIn.GetAllTypeParameters(); + + var compilation = Document.SemanticModel.Compilation; + var allTypeParameters = availableMethodTypeParameters.Concat(availableTypeParameters); + var availableTypeParameterNames = allTypeParameters.Select(t => t.Name).ToSet(); + + var typeArgumentToTypeParameterMap = GetTypeArgumentToTypeParameterMap(cancellationToken); + + typeSymbol = typeSymbol.RemoveAnonymousTypes(compilation); + typeSymbol = await ReplaceTypeParametersBasedOnTypeConstraintsAsync( + Document.Project, typeSymbol, compilation, availableTypeParameterNames, cancellationToken).ConfigureAwait(false); + return typeSymbol.RemoveUnavailableTypeParameters(compilation, allTypeParameters) + .RemoveUnnamedErrorTypes(compilation) + .SubstituteTypes(typeArgumentToTypeParameterMap, new TypeGenerator()); + } - return CodeGenerationSymbolFactory.CreatePropertySymbol( - attributes: default, - accessibility: accessibility, - modifiers: new DeclarationModifiers(isStatic: State.IsStatic, isAbstract: isAbstract), - type: await DetermineReturnTypeAsync(cancellationToken).ConfigureAwait(false), - refKind: DetermineRefKind(cancellationToken), - explicitInterfaceImplementations: default, - name: State.IdentifierToken.ValueText, - parameters: await DetermineParametersAsync(cancellationToken).ConfigureAwait(false), - getMethod: getMethod, - setMethod: setMethod); - } + private IDictionary GetTypeArgumentToTypeParameterMap( + CancellationToken cancellationToken) + { + return _typeArgumentToTypeParameterMap ??= CreateTypeArgumentToTypeParameterMap(cancellationToken); + } - public async ValueTask GenerateMethodAsync( - SyntaxGenerator factory, - bool isAbstract, - CancellationToken cancellationToken) - { - var parameters = await DetermineParametersAsync(cancellationToken).ConfigureAwait(false); - var returnType = await DetermineReturnTypeAsync(cancellationToken).ConfigureAwait(false); - var isUnsafe = false; - if (!State.IsContainedInUnsafeType) - { - isUnsafe = returnType.RequiresUnsafeModifier() || parameters.Any(static p => p.Type.RequiresUnsafeModifier()); - } + private IDictionary CreateTypeArgumentToTypeParameterMap( + CancellationToken cancellationToken) + { + var typeArguments = DetermineTypeArguments(cancellationToken); + var typeParameters = DetermineTypeParameters(cancellationToken); - var method = CodeGenerationSymbolFactory.CreateMethodSymbol( - attributes: default, - accessibility: DetermineAccessibility(isAbstract), - modifiers: new DeclarationModifiers(isStatic: State.IsStatic, isAbstract: isAbstract, isUnsafe: isUnsafe), - returnType: returnType, - refKind: DetermineRefKind(cancellationToken), - explicitInterfaceImplementations: default, - name: State.IdentifierToken.ValueText, - typeParameters: DetermineTypeParameters(cancellationToken), - parameters: parameters, - statements: GenerateStatements(factory, isAbstract), - handlesExpressions: default, - returnTypeAttributes: default, - methodKind: State.MethodKind); - - // Ensure no conflicts between type parameter names and parameter names. - var languageServiceProvider = Document.Project.Solution.Services.GetLanguageServices(State.TypeToGenerateIn.Language); - var syntaxFacts = languageServiceProvider.GetService(); - - var equalityComparer = syntaxFacts.StringComparer; - var reservedParameterNames = DetermineParameterNames(cancellationToken) - .Select(p => p.BestNameForParameter) - .ToSet(equalityComparer); - - var newTypeParameterNames = NameGenerator.EnsureUniqueness( - method.TypeParameters.SelectAsArray(t => t.Name), - n => !reservedParameterNames.Contains(n)); - - return method.RenameTypeParameters(newTypeParameterNames); - } + // We use a nullability-ignoring comparer because top-level and nested nullability won't matter. If we are looking to replace + // IEnumerable with T, we want to replace IEnumerable whenever it appears in an argument or return type, partly because + // there's no way to represent something like T-with-only-the-inner-thing-nullable. We could leave the entire argument as is, but we're suspecting + // this is closer to the user's desire, even if it might require some tweaking after the fact. + var result = new Dictionary(SymbolEqualityComparer.Default); - private async ValueTask FixTypeAsync( - ITypeSymbol typeSymbol, - CancellationToken cancellationToken) + for (var i = 0; i < typeArguments.Length; i++) { - // A type can't refer to a type parameter that isn't available in the type we're - // eventually generating into. - var availableMethodTypeParameters = DetermineTypeParameters(cancellationToken); - var availableTypeParameters = State.TypeToGenerateIn.GetAllTypeParameters(); - - var compilation = Document.SemanticModel.Compilation; - var allTypeParameters = availableMethodTypeParameters.Concat(availableTypeParameters); - var availableTypeParameterNames = allTypeParameters.Select(t => t.Name).ToSet(); - - var typeArgumentToTypeParameterMap = GetTypeArgumentToTypeParameterMap(cancellationToken); - - typeSymbol = typeSymbol.RemoveAnonymousTypes(compilation); - typeSymbol = await ReplaceTypeParametersBasedOnTypeConstraintsAsync( - Document.Project, typeSymbol, compilation, availableTypeParameterNames, cancellationToken).ConfigureAwait(false); - return typeSymbol.RemoveUnavailableTypeParameters(compilation, allTypeParameters) - .RemoveUnnamedErrorTypes(compilation) - .SubstituteTypes(typeArgumentToTypeParameterMap, new TypeGenerator()); + if (typeArguments[i] != null) + { + result[typeArguments[i]] = typeParameters[i]; + } } - private IDictionary GetTypeArgumentToTypeParameterMap( - CancellationToken cancellationToken) - { - return _typeArgumentToTypeParameterMap ??= CreateTypeArgumentToTypeParameterMap(cancellationToken); - } + return result; + } - private IDictionary CreateTypeArgumentToTypeParameterMap( - CancellationToken cancellationToken) - { - var typeArguments = DetermineTypeArguments(cancellationToken); - var typeParameters = DetermineTypeParameters(cancellationToken); + private ImmutableArray GenerateStatements( + SyntaxGenerator factory, + bool isAbstract) + { + var throwStatement = CodeGenerationHelpers.GenerateThrowStatement(factory, Document, "System.NotImplementedException"); - // We use a nullability-ignoring comparer because top-level and nested nullability won't matter. If we are looking to replace - // IEnumerable with T, we want to replace IEnumerable whenever it appears in an argument or return type, partly because - // there's no way to represent something like T-with-only-the-inner-thing-nullable. We could leave the entire argument as is, but we're suspecting - // this is closer to the user's desire, even if it might require some tweaking after the fact. - var result = new Dictionary(SymbolEqualityComparer.Default); + return isAbstract || State.TypeToGenerateIn.TypeKind == TypeKind.Interface || throwStatement == null + ? default + : ImmutableArray.Create(throwStatement); + } - for (var i = 0; i < typeArguments.Length; i++) - { - if (typeArguments[i] != null) - { - result[typeArguments[i]] = typeParameters[i]; - } - } + private async ValueTask> DetermineParametersAsync(CancellationToken cancellationToken) + { + var modifiers = DetermineParameterModifiers(cancellationToken); + var types = await SpecializedTasks.WhenAll(DetermineParameterTypes(cancellationToken).Select(t => FixTypeAsync(t, cancellationToken))).ConfigureAwait(false); + var optionality = DetermineParameterOptionality(cancellationToken); + var names = DetermineParameterNames(cancellationToken); - return result; + using var _ = ArrayBuilder.GetInstance(out var result); + for (var i = 0; i < modifiers.Length; i++) + { + result.Add(CodeGenerationSymbolFactory.CreateParameterSymbol( + attributes: default, + refKind: modifiers[i], + isParams: false, + isOptional: optionality[i], + type: types[i], + name: names[i].BestNameForParameter)); } - private ImmutableArray GenerateStatements( - SyntaxGenerator factory, - bool isAbstract) - { - var throwStatement = CodeGenerationHelpers.GenerateThrowStatement(factory, Document, "System.NotImplementedException"); + return result.ToImmutable(); + } - return isAbstract || State.TypeToGenerateIn.TypeKind == TypeKind.Interface || throwStatement == null - ? default - : ImmutableArray.Create(throwStatement); - } + private Accessibility DetermineAccessibility(bool isAbstract) + { + var containingType = State.ContainingType; - private async ValueTask> DetermineParametersAsync(CancellationToken cancellationToken) + // If we're generating into an interface, then we don't use any modifiers. + if (State.TypeToGenerateIn.TypeKind != TypeKind.Interface) { - var modifiers = DetermineParameterModifiers(cancellationToken); - var types = await SpecializedTasks.WhenAll(DetermineParameterTypes(cancellationToken).Select(t => FixTypeAsync(t, cancellationToken))).ConfigureAwait(false); - var optionality = DetermineParameterOptionality(cancellationToken); - var names = DetermineParameterNames(cancellationToken); - - using var _ = ArrayBuilder.GetInstance(out var result); - for (var i = 0; i < modifiers.Length; i++) + // Otherwise, figure out what accessibility modifier to use and optionally + // mark it as static. + if (containingType.IsContainedWithin(State.TypeToGenerateIn)) { - result.Add(CodeGenerationSymbolFactory.CreateParameterSymbol( - attributes: default, - refKind: modifiers[i], - isParams: false, - isOptional: optionality[i], - type: types[i], - name: names[i].BestNameForParameter)); + return isAbstract ? Accessibility.Protected : Accessibility.Private; } - - return result.ToImmutable(); - } - - private Accessibility DetermineAccessibility(bool isAbstract) - { - var containingType = State.ContainingType; - - // If we're generating into an interface, then we don't use any modifiers. - if (State.TypeToGenerateIn.TypeKind != TypeKind.Interface) + else if (DerivesFrom(containingType) && State.IsStatic) { - // Otherwise, figure out what accessibility modifier to use and optionally - // mark it as static. - if (containingType.IsContainedWithin(State.TypeToGenerateIn)) - { - return isAbstract ? Accessibility.Protected : Accessibility.Private; - } - else if (DerivesFrom(containingType) && State.IsStatic) - { - // NOTE(cyrusn): We only generate protected in the case of statics. Consider - // the case where we're generating into one of our base types. i.e.: - // - // class B : A { void Goo() { A a; a.Goo(); } - // - // In this case we can *not* mark the method as protected. 'B' can only - // access protected members of 'A' through an instance of 'B' (or a subclass - // of B). It can not access protected members through an instance of the - // superclass. In this case we need to make the method public or internal. - // - // However, this does not apply if the method will be static. i.e. - // - // class B : A { void Goo() { A.Goo(); } - // - // B can access the protected statics of A, and so we generate 'Goo' as - // protected. - - // TODO: Code coverage - return Accessibility.Protected; - } - else if (containingType.ContainingAssembly.IsSameAssemblyOrHasFriendAccessTo(State.TypeToGenerateIn.ContainingAssembly)) - { - return Accessibility.Internal; - } - else - { - // TODO: Code coverage - return Accessibility.Public; - } + // NOTE(cyrusn): We only generate protected in the case of statics. Consider + // the case where we're generating into one of our base types. i.e.: + // + // class B : A { void Goo() { A a; a.Goo(); } + // + // In this case we can *not* mark the method as protected. 'B' can only + // access protected members of 'A' through an instance of 'B' (or a subclass + // of B). It can not access protected members through an instance of the + // superclass. In this case we need to make the method public or internal. + // + // However, this does not apply if the method will be static. i.e. + // + // class B : A { void Goo() { A.Goo(); } + // + // B can access the protected statics of A, and so we generate 'Goo' as + // protected. + + // TODO: Code coverage + return Accessibility.Protected; + } + else if (containingType.ContainingAssembly.IsSameAssemblyOrHasFriendAccessTo(State.TypeToGenerateIn.ContainingAssembly)) + { + return Accessibility.Internal; + } + else + { + // TODO: Code coverage + return Accessibility.Public; } - - return Accessibility.NotApplicable; } - private bool DerivesFrom(INamedTypeSymbol containingType) - { - return containingType.GetBaseTypes().Select(t => t.OriginalDefinition) - .OfType() - .Contains(State.TypeToGenerateIn); - } + return Accessibility.NotApplicable; + } + + private bool DerivesFrom(INamedTypeSymbol containingType) + { + return containingType.GetBaseTypes().Select(t => t.OriginalDefinition) + .OfType() + .Contains(State.TypeToGenerateIn); } } } diff --git a/src/Features/VisualBasicTest/GenerateMethod/GenerateMethodTests.vb b/src/Features/VisualBasicTest/GenerateMethod/GenerateMethodTests.vb index 06f6f1da58e79..8786fe45cb85b 100644 --- a/src/Features/VisualBasicTest/GenerateMethod/GenerateMethodTests.vb +++ b/src/Features/VisualBasicTest/GenerateMethod/GenerateMethodTests.vb @@ -2468,6 +2468,7 @@ End Class + Public Async Function TestGenerateMethodForAwaitWithoutParenthesis() As Task Await TestInRegularAndScriptAsync( Module Module1 @@ -2484,7 +2485,7 @@ Module Module1 Dim x = Await Goo End Sub - Private Function Goo() As Task(Of Object) + Private Async Function Goo() As Task(Of Object) Throw New NotImplementedException() End Function End Module @@ -3869,6 +3870,7 @@ End Module") + Public Async Function TestGenerateMethodConfigureAwaitFalse() As Task Await TestInRegularAndScriptAsync( "Imports System @@ -3888,7 +3890,7 @@ Module Program Dim x As Boolean = Await Goo().ConfigureAwait(False) End Sub - Private Function Goo() As Task(Of Boolean) + Private Async Function Goo() As Task(Of Boolean) Throw New NotImplementedException() End Function End Module") @@ -3926,6 +3928,7 @@ index:=1) + Public Async Function TestGenerateMethodWithMethodChaining() As Task Await TestInRegularAndScriptAsync( "Imports System @@ -3943,7 +3946,7 @@ Module M Dim x As Boolean = Await F().ConfigureAwait(False) End Sub - Private Function F() As Task(Of Boolean) + Private Async Function F() As Task(Of Boolean) Throw New NotImplementedException() End Function End Module") diff --git a/src/Workspaces/SharedUtilitiesAndExtensions/Workspace/Core/Extensions/KnownTypes.cs b/src/Workspaces/SharedUtilitiesAndExtensions/Workspace/Core/Extensions/KnownTypes.cs index 43e526a8825fe..df9149e18b02a 100644 --- a/src/Workspaces/SharedUtilitiesAndExtensions/Workspace/Core/Extensions/KnownTypes.cs +++ b/src/Workspaces/SharedUtilitiesAndExtensions/Workspace/Core/Extensions/KnownTypes.cs @@ -4,30 +4,36 @@ namespace Microsoft.CodeAnalysis.Shared.Extensions; -internal readonly struct KnownTypes +internal readonly struct KnownTypes(Compilation compilation) { - public readonly INamedTypeSymbol? TaskType; - public readonly INamedTypeSymbol? TaskOfTType; - public readonly INamedTypeSymbol? ValueTaskType; - public readonly INamedTypeSymbol? ValueTaskOfTTypeOpt; + public readonly INamedTypeSymbol? TaskType = compilation.TaskType(); + public readonly INamedTypeSymbol? TaskOfTType = compilation.TaskOfTType(); + public readonly INamedTypeSymbol? ValueTaskType = compilation.ValueTaskType(); + public readonly INamedTypeSymbol? ValueTaskOfTType = compilation.ValueTaskOfTType(); - public readonly INamedTypeSymbol? IEnumerableOfTType; - public readonly INamedTypeSymbol? IEnumeratorOfTType; + public readonly INamedTypeSymbol? IEnumerableOfTType = compilation.IEnumerableOfTType(); + public readonly INamedTypeSymbol? IEnumeratorOfTType = compilation.IEnumeratorOfTType(); - public readonly INamedTypeSymbol? IAsyncEnumerableOfTTypeOpt; - public readonly INamedTypeSymbol? IAsyncEnumeratorOfTTypeOpt; + public readonly INamedTypeSymbol? IAsyncEnumerableOfTType = compilation.IAsyncEnumerableOfTType(); + public readonly INamedTypeSymbol? IAsyncEnumeratorOfTType = compilation.IAsyncEnumeratorOfTType(); - internal KnownTypes(Compilation compilation) + public bool IsTaskLike(ITypeSymbol returnType) { - TaskType = compilation.TaskType(); - TaskOfTType = compilation.TaskOfTType(); - ValueTaskType = compilation.ValueTaskType(); - ValueTaskOfTTypeOpt = compilation.ValueTaskOfTType(); + if (returnType.Equals(this.TaskType)) + return true; - IEnumerableOfTType = compilation.IEnumerableOfTType(); - IEnumeratorOfTType = compilation.IEnumeratorOfTType(); + if (returnType.Equals(this.ValueTaskType)) + return true; - IAsyncEnumerableOfTTypeOpt = compilation.IAsyncEnumerableOfTType(); - IAsyncEnumeratorOfTTypeOpt = compilation.IAsyncEnumeratorOfTType(); + if (returnType.OriginalDefinition.Equals(this.TaskOfTType)) + return true; + + if (returnType.OriginalDefinition.Equals(this.ValueTaskOfTType)) + return true; + + if (returnType.IsErrorType()) + return returnType.Name is "Task" or "ValueTask"; + + return false; } }