diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/ExampleTests.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/ExampleTests.cs index 60c94be8c182d..9fc80b18c4c62 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/ExampleTests.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/ExampleTests.cs @@ -3,16 +3,8 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; -using System.IO.Pipelines; -using System.Linq; -using System.Text; -using System.Threading; using System.Threading.Tasks; -using Microsoft.CommonLanguageServerProtocol.Framework; using Microsoft.VisualStudio.LanguageServer.Protocol; -using Nerdbank.Streams; -using StreamJsonRpc; using Xunit; namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/HandlerProviderTests.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/HandlerProviderTests.cs index f423df0014e20..61d943c592ed8 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/HandlerProviderTests.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/HandlerProviderTests.cs @@ -4,120 +4,97 @@ using System; using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.CommonLanguageServerProtocol.Framework; +using System.Linq; using Xunit; namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; -public partial class HandlerProviderTests +public class HandlerProviderTests { - private const string _method = "SomeMethod"; - private const string _wrongMethod = "WrongMethod"; - private static readonly Type _requestType = typeof(int); - private static readonly Type _responseType = typeof(string); - private static readonly Type _wrongResponseType = typeof(long); + [Theory] + [CombinatorialData] + public void GetMethodHandler(bool supportsGetRegisteredServices) + { + var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices); - private static readonly IMethodHandler _expectedMethodHandler = new TestMethodHandler(); + var methodHander = handlerProvider.GetMethodHandler(TestMethodHandler.Name, TestMethodHandler.RequestType, TestMethodHandler.ResponseType); + Assert.Same(TestMethodHandler.Instance, methodHander); + } - [Fact] - public void GetMethodHandler_ViaGetRequiredServices_Succeeds() + [Theory] + [CombinatorialData] + public void GetMethodHandler_Parameterless(bool supportsGetRegisteredServices) { - var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false); + var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices); - var methodHander = handlerProvider.GetMethodHandler(_method, _requestType, _responseType); - - Assert.Same(_expectedMethodHandler, methodHander); + var methodHander = handlerProvider.GetMethodHandler(TestParameterlessMethodHandler.Name, requestType: null, TestParameterlessMethodHandler.ResponseType); + Assert.Same(TestParameterlessMethodHandler.Instance, methodHander); } - [Fact] - public void GetMethodHandler_ViaGetRegisteredServices_Succeeds() + [Theory] + [CombinatorialData] + public void GetMethodHandler_Notification(bool supportsGetRegisteredServices) { - var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: true); + var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices); - var methodHander = handlerProvider.GetMethodHandler(_method, _requestType, _responseType); - - Assert.Same(_expectedMethodHandler, methodHander); + var methodHander = handlerProvider.GetMethodHandler(TestNotificationHandler.Name, TestNotificationHandler.RequestType, responseType: null); + Assert.Same(TestNotificationHandler.Instance, methodHander); } - [Fact] - public void GetMethodHandler_WrongMethod_Throws() + [Theory] + [CombinatorialData] + public void GetMethodHandler_ParameterlessNotification(bool supportsGetRegisteredServices) { - var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false); + var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices); - Assert.Throws(() => handlerProvider.GetMethodHandler(_wrongMethod, _requestType, _responseType)); + var methodHander = handlerProvider.GetMethodHandler(TestParameterlessNotificationHandler.Name, requestType: null, responseType: null); + Assert.Same(TestParameterlessNotificationHandler.Instance, methodHander); } [Fact] - public void GetMethodHandler_WrongResponseType_Throws() + public void GetMethodHandler_WrongMethod_Throws() { var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false); - Assert.Throws(() => handlerProvider.GetMethodHandler(_method, _requestType, _wrongResponseType)); + Assert.Throws(() => handlerProvider.GetMethodHandler("UndefinedMethod", TestMethodHandler.RequestType, TestMethodHandler.ResponseType)); } [Fact] - public void GetRegisteredMethods_GetRequiredServices() + public void GetMethodHandler_WrongResponseType_Throws() { var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false); - var registeredMethods = handlerProvider.GetRegisteredMethods(); - - Assert.Collection(registeredMethods, - (r) => Assert.Equal(_method, r.MethodName)); + Assert.Throws(() => handlerProvider.GetMethodHandler(TestMethodHandler.Name, TestMethodHandler.RequestType, responseType: typeof(long))); } - [Fact] - public void GetRegisteredMethods_GetRegisteredServices() + [Theory] + [CombinatorialData] + public void GetRegisteredMethods(bool supportsGetRegisteredServices) { - var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: true); + var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices); - var registeredMethods = handlerProvider.GetRegisteredMethods(); + var registeredMethods = handlerProvider.GetRegisteredMethods().OrderBy(m => m.MethodName); Assert.Collection(registeredMethods, - (r) => Assert.Equal(_method, r.MethodName)); + r => Assert.Equal(TestMethodHandler.Name, r.MethodName), + r => Assert.Equal(TestNotificationHandler.Name, r.MethodName), + r => Assert.Equal(TestParameterlessMethodHandler.Name, r.MethodName), + r => Assert.Equal(TestParameterlessNotificationHandler.Name, r.MethodName)); } private static HandlerProvider GetHandlerProvider(bool supportsGetRegisteredServices) - { - var lspServices = GetLspServices(supportsGetRegisteredServices); - var handler = new HandlerProvider(lspServices); - - return handler; - } + => new(GetLspServices(supportsGetRegisteredServices)); - private static ILspServices GetLspServices(bool supportsGetRegisteredServices) + private static TestLspServices GetLspServices(bool supportsGetRegisteredServices) { - var services = new List<(Type, object)> { (typeof(IMethodHandler), _expectedMethodHandler) }; - var lspServices = new TestLspServices(services, supportsGetRegisteredServices); - return lspServices; - } - - [LanguageServerEndpoint(_method)] - internal class TestMethodHandler : IRequestHandler - { - public bool MutatesSolutionState => true; - - public static string Method = _method; - - public static Type RequestType = typeof(int); - - public static Type ResponseType = typeof(string); - - public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + var services = new List<(Type, object)> { - return Task.FromResult("stuff"); - } - } + (typeof(IMethodHandler), TestMethodHandler.Instance), + (typeof(IMethodHandler), TestNotificationHandler.Instance), + (typeof(IMethodHandler), TestParameterlessMethodHandler.Instance), + (typeof(IMethodHandler), TestParameterlessNotificationHandler.Instance), + }; - private class TestMethodHandlerWithoutAttribute : INotificationHandler - { - public bool MutatesSolutionState => true; - - public Task HandleNotificationAsync(TestRequestContext requestContext, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + return new TestLspServices(services, supportsGetRegisteredServices); } } diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/NoOpLspLogger.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/NoOpLspLogger.cs similarity index 88% rename from src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/NoOpLspLogger.cs rename to src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/NoOpLspLogger.cs index f833d793e03ee..0264918f7027f 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/NoOpLspLogger.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/NoOpLspLogger.cs @@ -5,13 +5,12 @@ using System; using System.Threading; using System.Threading.Tasks; -using Microsoft.CommonLanguageServerProtocol.Framework; namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; public class NoOpLspLogger : ILspLogger { - public static NoOpLspLogger Instance = new NoOpLspLogger(); + public static NoOpLspLogger Instance = new(); public void LogError(string message, params object[] @params) { diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestHandlerProvider.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestHandlerProvider.cs new file mode 100644 index 0000000000000..88802ca0d0b6f --- /dev/null +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestHandlerProvider.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; + +namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; + +internal class TestHandlerProvider : IHandlerProvider +{ + private readonly IEnumerable<(RequestHandlerMetadata metadata, IMethodHandler provider)> _providers; + + public TestHandlerProvider(IEnumerable<(RequestHandlerMetadata metadata, IMethodHandler provider)> providers) + => _providers = providers; + + public IMethodHandler GetMethodHandler(string method, Type? requestType, Type? responseType) + => _providers.Single(p => p.metadata.MethodName == method).provider; + + public ImmutableArray GetRegisteredMethods() + => _providers.Select(p => p.metadata).ToImmutableArray(); +} diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestLspServices.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestLspServices.cs new file mode 100644 index 0000000000000..7eaad02749422 --- /dev/null +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestLspServices.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; + +namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; + +internal class TestLspServices : ILspServices +{ + private readonly bool _supportsGetRegisteredServices; + private readonly IEnumerable<(Type type, object instance)> _services; + + public TestLspServices(IEnumerable<(Type type, object instance)> services, bool supportsGetRegisteredServices) + { + _services = services; + _supportsGetRegisteredServices = supportsGetRegisteredServices; + } + + public void Dispose() + { + } + + public ImmutableArray GetRegisteredServices() + => _services.Select(s => s.instance.GetType()).ToImmutableArray(); + + public T GetRequiredService() where T : notnull + => (T?)TryGetService(typeof(T)) ?? throw new InvalidOperationException($"{typeof(T).Name} did not have a service"); + + public IEnumerable GetRequiredServices() + => _supportsGetRegisteredServices ? Array.Empty() : _services.Where(s => s.instance is T).Select(s => (T)s.instance); + + public bool SupportsGetRegisteredServices() + => _supportsGetRegisteredServices; + + public object? TryGetService(Type type) + => _services.FirstOrDefault(s => (_supportsGetRegisteredServices ? s.instance.GetType() : s.type) == type).instance; +} diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestMethodHandlers.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestMethodHandlers.cs new file mode 100644 index 0000000000000..36df8adba2821 --- /dev/null +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestMethodHandlers.cs @@ -0,0 +1,151 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; + +[LanguageServerEndpoint(Name)] +internal class TestMethodHandler : IRequestHandler +{ + public const string Name = "Method"; + public static readonly IMethodHandler Instance = new TestMethodHandler(); + + public bool MutatesSolutionState => true; + public static Type RequestType = typeof(int); + public static Type ResponseType = typeof(string); + public static RequestHandlerMetadata Metadata = new(Name, RequestType, ResponseType); + + public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + => Task.FromResult("stuff"); +} + +[LanguageServerEndpoint(Name)] +internal class TestParameterlessMethodHandler : IRequestHandler +{ + public const string Name = "ParameterlessMethod"; + public static readonly IMethodHandler Instance = new TestParameterlessMethodHandler(); + + public bool MutatesSolutionState => true; + + public static Type ResponseType = typeof(bool); + public static RequestHandlerMetadata Metadata = new(Name, RequestType: null, ResponseType); + + public Task HandleRequestAsync(TestRequestContext context, CancellationToken cancellationToken) + => Task.FromResult(true); +} + +[LanguageServerEndpoint(Name)] +internal class TestNotificationHandler : INotificationHandler +{ + public const string Name = "Notification"; + public static readonly IMethodHandler Instance = new TestNotificationHandler(); + + public bool MutatesSolutionState => true; + public static Type RequestType = typeof(bool); + public static readonly RequestHandlerMetadata Metadata = new(Name, RequestType, ResponseType: null); + + public Task HandleNotificationAsync(bool request, TestRequestContext context, CancellationToken cancellationToken) + => Task.FromResult(true); +} + +[LanguageServerEndpoint(Name)] +internal class TestParameterlessNotificationHandler : INotificationHandler +{ + public const string Name = "ParameterlessNotification"; + public static readonly IMethodHandler Instance = new TestParameterlessNotificationHandler(); + + public bool MutatesSolutionState => true; + public static readonly RequestHandlerMetadata Metadata = new(Name, RequestType: null, ResponseType: null); + + public Task HandleNotificationAsync(TestRequestContext context, CancellationToken cancellationToken) + => Task.FromResult(true); +} + +internal class TestMethodHandlerWithoutAttribute : INotificationHandler +{ + public bool MutatesSolutionState => true; + + public Task HandleNotificationAsync(TestRequestContext requestContext, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } +} + +[LanguageServerEndpoint(Name)] +public class MutatingHandler : IRequestHandler +{ + public const string Name = "MutatingMethod"; + public static readonly IMethodHandler Instance = new MutatingHandler(); + public static readonly RequestHandlerMetadata Metadata = new(Name, RequestType: typeof(int), ResponseType: typeof(string)); + + public MutatingHandler() + { + } + + public bool MutatesSolutionState => true; + + public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + return Task.FromResult(string.Empty); + } +} + +[LanguageServerEndpoint(Name)] +public class CompletingHandler : IRequestHandler +{ + public const string Name = "CompletingMethod"; + public static readonly IMethodHandler Instance = new CompletingHandler(); + public static readonly RequestHandlerMetadata Metadata = new(Name, RequestType: typeof(int), ResponseType: typeof(string)); + + public bool MutatesSolutionState => false; + + public async Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + while (true) + { + if (cancellationToken.IsCancellationRequested) + { + return "I completed!"; + } + await Task.Delay(100); + } + } +} + +[LanguageServerEndpoint(Name)] +public class CancellingHandler : IRequestHandler +{ + public const string Name = "CancellingMethod"; + public static readonly IMethodHandler Instance = new CancellingHandler(); + public static readonly RequestHandlerMetadata Metadata = new(Name, RequestType: typeof(int), ResponseType: typeof(string)); + + public bool MutatesSolutionState => false; + + public async Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + await Task.Delay(100); + } + } +} + +[LanguageServerEndpoint(Name)] +public class ThrowingHandler : IRequestHandler +{ + public const string Name = "ThrowingMethod"; + public static readonly IMethodHandler Instance = new ThrowingHandler(); + public static readonly RequestHandlerMetadata Metadata = new(Name, RequestType: typeof(int), ResponseType: typeof(string)); + + public bool MutatesSolutionState => false; + + public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } +} diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestRequestContext.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestRequestContext.cs new file mode 100644 index 0000000000000..1eab00ca80914 --- /dev/null +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/Mocks/TestRequestContext.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; + +public class TestRequestContext +{ + public class Factory : IRequestContextFactory + { + public static readonly Factory Instance = new(); + + public Task CreateRequestContextAsync(IQueueItem queueItem, TRequestParam requestParam, CancellationToken cancellationToken) + => Task.FromResult(new TestRequestContext()); + } +} diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs index 2a0349580a4ac..cb2906f461d3a 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/RequestExecutionQueueTests.cs @@ -3,19 +3,11 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; using System.Threading; using System.Threading.Tasks; -using Microsoft.CodeAnalysis.Elfie.Diagnostics; -using Microsoft.CommonLanguageServerProtocol.Framework; -using Moq; using Nerdbank.Streams; using StreamJsonRpc; using Xunit; -using static Microsoft.CommonLanguageServerProtocol.Framework.UnitTests.HandlerProviderTests; -using static Microsoft.CommonLanguageServerProtocol.Framework.UnitTests.RequestExecutionQueueTests; namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; @@ -33,63 +25,32 @@ protected override ILspServices ConstructLspServices() } } - private const string MethodName = "SomeMethod"; - private const string CancellingMethod = "CancellingMethod"; - private const string CompletingMethod = "CompletingMethod"; - private const string MutatingMethod = "MutatingMethod"; - - private static RequestExecutionQueue GetRequestExecutionQueue(bool cancelInProgressWorkUponMutatingRequest, params IMethodHandler[] methodHandlers) + private static RequestExecutionQueue GetRequestExecutionQueue( + bool cancelInProgressWorkUponMutatingRequest, + params (RequestHandlerMetadata metadata, IMethodHandler handler)[] handlers) { - var handlerProvider = new Mock(MockBehavior.Strict); - if (methodHandlers.Length == 0) - { - var handler = GetTestMethodHandler(); - handlerProvider.Setup(h => h.GetMethodHandler(MethodName, TestMethodHandler.RequestType, TestMethodHandler.ResponseType)).Returns(handler); - } - - foreach (var methodHandler in methodHandlers) - { - var methodType = methodHandler.GetType(); - var methodAttribute = methodType.GetCustomAttribute(); - var method = methodAttribute.Method; + var provider = new TestHandlerProvider(handlers); - handlerProvider.Setup(h => h.GetMethodHandler(method, typeof(int), typeof(string))).Returns(methodHandler); - } - - var executionQueue = new TestRequestExecutionQueue(new MockServer(), NoOpLspLogger.Instance, handlerProvider.Object, cancelInProgressWorkUponMutatingRequest); + var executionQueue = new TestRequestExecutionQueue(new MockServer(), NoOpLspLogger.Instance, provider, cancelInProgressWorkUponMutatingRequest); executionQueue.Start(); return executionQueue; } - private static ILspServices GetLspServices() - { - var requestContextFactory = new Mock>(MockBehavior.Strict); - requestContextFactory.Setup(f => f.CreateRequestContextAsync(It.IsAny>(), It.IsAny(), It.IsAny())) - .Returns(Task.FromResult(new TestRequestContext())); - var services = new List<(Type, object)> { (typeof(IRequestContextFactory), requestContextFactory.Object) }; - var lspServices = new TestLspServices(services, supportsGetRegisteredServices: false); - - return lspServices; - } - - private static TestMethodHandler GetTestMethodHandler() - { - var methodHandler = new TestMethodHandler(); - - return methodHandler; - } + private static TestLspServices GetLspServices() + => new( + services: new[] { (typeof(IRequestContextFactory), (object)TestRequestContext.Factory.Instance) }, + supportsGetRegisteredServices: false); [Fact] public async Task ExecuteAsync_ThrowCompletes() { // Arrange - var throwingHandler = new ThrowingHandler(); - var requestExecutionQueue = GetRequestExecutionQueue(false, throwingHandler); + var requestExecutionQueue = GetRequestExecutionQueue(false, (ThrowingHandler.Metadata, ThrowingHandler.Instance)); var lspServices = GetLspServices(); // Act & Assert - await Assert.ThrowsAsync(() => requestExecutionQueue.ExecuteAsync(1, MethodName, lspServices, CancellationToken.None)); + await Assert.ThrowsAsync(() => requestExecutionQueue.ExecuteAsync(1, ThrowingHandler.Name, lspServices, CancellationToken.None)); } [Fact] @@ -99,21 +60,23 @@ public async Task ExecuteAsync_WithCancelInProgressWork_CancelsInProgressWorkWhe for (var i = 0; i < 20; i++) { // Arrange - var mutatingHandler = new MutatingHandler(); - var cancellingHandler = new CancellingHandler(); - var completingHandler = new CompletingHandler(); - var requestExecutionQueue = GetRequestExecutionQueue(cancelInProgressWorkUponMutatingRequest: true, methodHandlers: new IMethodHandler[] { cancellingHandler, completingHandler, mutatingHandler }); + var requestExecutionQueue = GetRequestExecutionQueue(cancelInProgressWorkUponMutatingRequest: true, handlers: new[] + { + (CancellingHandler.Metadata, CancellingHandler.Instance), + (CompletingHandler.Metadata, CompletingHandler.Instance), + (MutatingHandler.Metadata, MutatingHandler.Instance), + }); var lspServices = GetLspServices(); var cancellingRequestCancellationToken = new CancellationToken(); var completingRequestCancellationToken = new CancellationToken(); - var _ = requestExecutionQueue.ExecuteAsync(1, CancellingMethod, lspServices, cancellingRequestCancellationToken); - var _1 = requestExecutionQueue.ExecuteAsync(1, CompletingMethod, lspServices, completingRequestCancellationToken); + var _ = requestExecutionQueue.ExecuteAsync(1, CancellingHandler.Name, lspServices, cancellingRequestCancellationToken); + var _1 = requestExecutionQueue.ExecuteAsync(1, CompletingHandler.Name, lspServices, completingRequestCancellationToken); // Act & Assert // A Debug.Assert would throw if the tasks hadn't completed when the mutating request is called. - await requestExecutionQueue.ExecuteAsync(1, MutatingMethod, lspServices, CancellationToken.None); + await requestExecutionQueue.ExecuteAsync(1, MutatingHandler.Name, lspServices, CancellationToken.None); } } @@ -121,7 +84,7 @@ public async Task ExecuteAsync_WithCancelInProgressWork_CancelsInProgressWorkWhe public async Task Dispose_MultipleTimes_Succeeds() { // Arrange - var requestExecutionQueue = GetRequestExecutionQueue(false); + var requestExecutionQueue = GetRequestExecutionQueue(false, (TestMethodHandler.Metadata, TestMethodHandler.Instance)); // Act await requestExecutionQueue.DisposeAsync(); @@ -133,100 +96,69 @@ public async Task Dispose_MultipleTimes_Succeeds() [Fact] public async Task ExecuteAsync_CompletesTask() { - var requestExecutionQueue = GetRequestExecutionQueue(false); - var request = 1; + var requestExecutionQueue = GetRequestExecutionQueue(false, (TestMethodHandler.Metadata, TestMethodHandler.Instance)); var lspServices = GetLspServices(); - var response = await requestExecutionQueue.ExecuteAsync(request, MethodName, lspServices, CancellationToken.None); - + var response = await requestExecutionQueue.ExecuteAsync(request: 1, TestMethodHandler.Name, lspServices, CancellationToken.None); Assert.Equal("stuff", response); } [Fact] - public async Task Queue_DrainsOnShutdown() + public async Task ExecuteAsync_CompletesTask_Parameterless() { - var requestExecutionQueue = GetRequestExecutionQueue(false); - var request = 1; + var requestExecutionQueue = GetRequestExecutionQueue(false, (TestParameterlessMethodHandler.Metadata, TestParameterlessMethodHandler.Instance)); var lspServices = GetLspServices(); - var task1 = requestExecutionQueue.ExecuteAsync(request, MethodName, lspServices, CancellationToken.None); - var task2 = requestExecutionQueue.ExecuteAsync(request, MethodName, lspServices, CancellationToken.None); - - await requestExecutionQueue.DisposeAsync(); - - Assert.True(task1.IsCompleted); - Assert.True(task2.IsCompleted); + var response = await requestExecutionQueue.ExecuteAsync(request: NoValue.Instance, TestParameterlessMethodHandler.Name, lspServices, CancellationToken.None); + Assert.True(response); } - private class TestRequestExecutionQueue : RequestExecutionQueue + [Fact] + public async Task ExecuteAsync_CompletesTask_Notification() { - private readonly bool _cancelInProgressWorkUponMutatingRequest; - - public TestRequestExecutionQueue(AbstractLanguageServer languageServer, ILspLogger logger, IHandlerProvider handlerProvider, bool cancelInProgressWorkUponMutatingRequest) - : base(languageServer, logger, handlerProvider) - { - _cancelInProgressWorkUponMutatingRequest = cancelInProgressWorkUponMutatingRequest; - } + var requestExecutionQueue = GetRequestExecutionQueue(false, (TestNotificationHandler.Metadata, TestNotificationHandler.Instance)); + var lspServices = GetLspServices(); - protected override bool CancelInProgressWorkUponMutatingRequest => _cancelInProgressWorkUponMutatingRequest; + var response = await requestExecutionQueue.ExecuteAsync(request: true, TestNotificationHandler.Name, lspServices, CancellationToken.None); + Assert.Same(NoValue.Instance, response); } - [LanguageServerEndpoint(MutatingMethod)] - public class MutatingHandler : IRequestHandler + [Fact] + public async Task ExecuteAsync_CompletesTask_Notification_Parameterless() { - public MutatingHandler() - { - } - - public bool MutatesSolutionState => true; + var requestExecutionQueue = GetRequestExecutionQueue(false, (TestParameterlessNotificationHandler.Metadata, TestParameterlessNotificationHandler.Instance)); + var lspServices = GetLspServices(); - public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) - { - return Task.FromResult(string.Empty); - } + var response = await requestExecutionQueue.ExecuteAsync(request: NoValue.Instance, TestParameterlessNotificationHandler.Name, lspServices, CancellationToken.None); + Assert.Same(NoValue.Instance, response); } - [LanguageServerEndpoint(CompletingMethod)] - public class CompletingHandler : IRequestHandler + [Fact] + public async Task Queue_DrainsOnShutdown() { - public bool MutatesSolutionState => false; + var requestExecutionQueue = GetRequestExecutionQueue(false, (TestMethodHandler.Metadata, TestMethodHandler.Instance)); + var request = 1; + var lspServices = GetLspServices(); - public async Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) - { - while (true) - { - if (cancellationToken.IsCancellationRequested) - { - return "I completed!"; - } - await Task.Delay(100); - } - } - } + var task1 = requestExecutionQueue.ExecuteAsync(request, TestMethodHandler.Name, lspServices, CancellationToken.None); + var task2 = requestExecutionQueue.ExecuteAsync(request, TestMethodHandler.Name, lspServices, CancellationToken.None); - [LanguageServerEndpoint(CancellingMethod)] - public class CancellingHandler : IRequestHandler - { - public bool MutatesSolutionState => false; + await requestExecutionQueue.DisposeAsync(); - public async Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) - { - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - await Task.Delay(100); - } - } + Assert.True(task1.IsCompleted); + Assert.True(task2.IsCompleted); } - [LanguageServerEndpoint(MethodName)] - public class ThrowingHandler : IRequestHandler + private class TestRequestExecutionQueue : RequestExecutionQueue { - public bool MutatesSolutionState => false; + private readonly bool _cancelInProgressWorkUponMutatingRequest; - public Task HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken) + public TestRequestExecutionQueue(AbstractLanguageServer languageServer, ILspLogger logger, IHandlerProvider handlerProvider, bool cancelInProgressWorkUponMutatingRequest) + : base(languageServer, logger, handlerProvider) { - throw new NotImplementedException(); + _cancelInProgressWorkUponMutatingRequest = cancelInProgressWorkUponMutatingRequest; } + + protected override bool CancelInProgressWorkUponMutatingRequest => _cancelInProgressWorkUponMutatingRequest; } } diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/TestLspServices.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/TestLspServices.cs deleted file mode 100644 index 631b2ce3376c0..0000000000000 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/TestLspServices.cs +++ /dev/null @@ -1,66 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Linq; -using Microsoft.CommonLanguageServerProtocol.Framework; - -namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; - -internal class TestLspServices : ILspServices -{ - private readonly bool _supportsGetRegisteredServices; - private readonly IEnumerable<(Type, object)> _services; - - public TestLspServices(IEnumerable<(Type, object)> services, bool supportsGetRegisteredServices) - { - _services = services; - _supportsGetRegisteredServices = supportsGetRegisteredServices; - } - - public void Dispose() - { - throw new NotImplementedException(); - } - - public ImmutableArray GetRegisteredServices() - { - var types = new List(); - foreach (var service in _services) - { - types.Add(service.Item2.GetType()); - } - - return types.ToImmutableArray(); - } - - public T GetRequiredService() where T : notnull - { - var service = (T?)TryGetService(typeof(T)); - if (service is null) - throw new InvalidOperationException($"{typeof(T).Name} did not have a service"); - - return service; - } - - public IEnumerable GetRequiredServices() - { - var services = _services.Where(s => !_supportsGetRegisteredServices && s.Item2 is IMethodHandler).Select(s => (T)s.Item2); - return services; - } - - public bool SupportsGetRegisteredServices() - { - return _supportsGetRegisteredServices; - } - - public object? TryGetService(Type type) - { - var service = _services.FirstOrDefault(s => (_supportsGetRegisteredServices ? s.Item2.GetType() : s.Item1) == type); - - return service.Item2; - } -} diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/TestRequestContext.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/TestRequestContext.cs deleted file mode 100644 index 7fa0b3383df60..0000000000000 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework.UnitTests/TestRequestContext.cs +++ /dev/null @@ -1,9 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests; - -public class TestRequestContext -{ -} diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs index 6e740b91f2622..7196583560a2c 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs @@ -93,17 +93,6 @@ protected virtual IHandlerProvider GetHandlerProvider() protected virtual void SetupRequestDispatcher(IHandlerProvider handlerProvider) { - var entryPointMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(DelegatingEntryPoint.EntryPointAsync)); - if (entryPointMethod is null) - throw new InvalidOperationException($"{typeof(DelegatingEntryPoint).FullName} is missing method {nameof(DelegatingEntryPoint.EntryPointAsync)}"); - var notificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(DelegatingEntryPoint.NotificationEntryPointAsync)); - if (notificationMethod is null) - throw new InvalidOperationException($"{typeof(DelegatingEntryPoint).FullName} is missing method {nameof(DelegatingEntryPoint.NotificationEntryPointAsync)}"); - - var parameterlessNotificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(DelegatingEntryPoint.ParameterlessNotificationEntryPointAsync)); - if (parameterlessNotificationMethod is null) - throw new InvalidOperationException($"{typeof(DelegatingEntryPoint).FullName} is missing method {nameof(DelegatingEntryPoint.ParameterlessNotificationEntryPointAsync)}"); - foreach (var metadata in handlerProvider.GetRegisteredMethods()) { // Instead of concretely defining methods for each LSP method, we instead dynamically construct the @@ -113,31 +102,17 @@ protected virtual void SetupRequestDispatcher(IHandlerProvider handlerProvider) // // We also do not use the StreamJsonRpc support for JToken as the rpc method parameters because we want // StreamJsonRpc to do the deserialization to handle streaming requests using IProgress. + + var method = DelegatingEntryPoint.GetMethodInstantiation(metadata.RequestType, metadata.ResponseType); + var delegatingEntryPoint = new DelegatingEntryPoint(metadata.MethodName, this); - MethodInfo genericEntryPointMethod; - if (metadata.RequestType is not null && metadata.ResponseType is not null) - { - genericEntryPointMethod = entryPointMethod.MakeGenericMethod(metadata.RequestType, metadata.ResponseType); - } - else if (metadata.RequestType is not null && metadata.ResponseType is null) - { - genericEntryPointMethod = notificationMethod.MakeGenericMethod(metadata.RequestType); - } - else if (metadata.RequestType is null && metadata.ResponseType is null) - { - // No need to genericize - genericEntryPointMethod = parameterlessNotificationMethod; - } - else - { - throw new NotImplementedException($"An unrecognized {nameof(RequestHandlerMetadata)} situation has occured"); - } var methodAttribute = new JsonRpcMethodAttribute(metadata.MethodName) { UseSingleObjectParameterDeserialization = true, }; - _jsonRpc.AddLocalRpcMethod(genericEntryPointMethod, delegatingEntryPoint, methodAttribute); + + _jsonRpc.AddLocalRpcMethod(method, delegatingEntryPoint, methodAttribute); } } @@ -171,23 +146,37 @@ protected IRequestExecutionQueue GetRequestExecutionQueue() /// Wrapper class to hold the method and properties from the /// that the method info passed to StreamJsonRpc is created from. /// - private class DelegatingEntryPoint + private sealed class DelegatingEntryPoint { private readonly string _method; private readonly AbstractLanguageServer _target; + private static readonly MethodInfo s_entryPointMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(EntryPointAsync)); + private static readonly MethodInfo s_parameterlessEntryPointMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(ParameterlessEntryPointAsync)); + private static readonly MethodInfo s_notificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(NotificationEntryPointAsync)); + private static readonly MethodInfo s_parameterlessNotificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(ParameterlessNotificationEntryPointAsync)); + public DelegatingEntryPoint(string method, AbstractLanguageServer target) { _method = method; _target = target; } + public static MethodInfo GetMethodInstantiation(Type? requestType, Type? responseType) + => (requestType, responseType) switch + { + (requestType: not null, responseType: not null) => s_entryPointMethod.MakeGenericMethod(requestType, responseType), + (requestType: null, responseType: not null) => s_parameterlessEntryPointMethod.MakeGenericMethod(responseType), + (requestType: not null, responseType: null) => s_notificationMethod.MakeGenericMethod(requestType), + (requestType: null, responseType: null) => s_parameterlessNotificationMethod, + }; + public async Task NotificationEntryPointAsync(TRequest request, CancellationToken cancellationToken) where TRequest : class { var queue = _target.GetRequestExecutionQueue(); var lspServices = _target.GetLspServices(); - _ = await queue.ExecuteAsync(request, _method, lspServices, cancellationToken).ConfigureAwait(false); + _ = await queue.ExecuteAsync(request, _method, lspServices, cancellationToken).ConfigureAwait(false); } public async Task ParameterlessNotificationEntryPointAsync(CancellationToken cancellationToken) @@ -195,7 +184,7 @@ public async Task ParameterlessNotificationEntryPointAsync(CancellationToken can var queue = _target.GetRequestExecutionQueue(); var lspServices = _target.GetLspServices(); - _ = await queue.ExecuteAsync(VoidReturn.Instance, _method, lspServices, cancellationToken).ConfigureAwait(false); + _ = await queue.ExecuteAsync(NoValue.Instance, _method, lspServices, cancellationToken).ConfigureAwait(false); } public async Task EntryPointAsync(TRequest request, CancellationToken cancellationToken) where TRequest : class @@ -207,6 +196,16 @@ public async Task ParameterlessNotificationEntryPointAsync(CancellationToken can return result; } + + public async Task ParameterlessEntryPointAsync(CancellationToken cancellationToken) + { + var queue = _target.GetRequestExecutionQueue(); + var lspServices = _target.GetLspServices(); + + var result = await queue.ExecuteAsync(NoValue.Instance, _method, lspServices, cancellationToken).ConfigureAwait(false); + + return result; + } } public Task WaitForExitAsync() diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/HandlerProvider.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/HandlerProvider.cs index d4ec2d69407ef..9d79e5d74451c 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/HandlerProvider.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/HandlerProvider.cs @@ -5,7 +5,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; -using System.Diagnostics.Contracts; +using System.Diagnostics; using System.Linq; using System.Reflection; using System.Threading; @@ -40,9 +40,8 @@ public IMethodHandler GetMethodHandler(string method, Type? requestType, Type? r { throw new InvalidOperationException($"Missing handler for {requestHandlerMetadata.MethodName}"); } - var handler = lazyHandler.Value; - return handler; + return lazyHandler.Value; } public ImmutableArray GetRegisteredMethods() @@ -52,14 +51,7 @@ public ImmutableArray GetRegisteredMethods() } private ImmutableDictionary> GetRequestHandlers() - { - if (_requestHandlers is null) - { - _requestHandlers = CreateMethodToHandlerMap(_lspServices); - } - - return _requestHandlers; - } + => _requestHandlers ??= CreateMethodToHandlerMap(_lspServices); private static ImmutableDictionary> CreateMethodToHandlerMap(ILspServices lspServices) { @@ -69,7 +61,7 @@ private static ImmutableDictionary> if (lspServices.SupportsGetRegisteredServices()) { - var requestHandlerTypes = lspServices.GetRegisteredServices().Where(type => IsTypeRequestHandler(type)); + var requestHandlerTypes = lspServices.GetRegisteredServices().Where(type => typeof(IMethodHandler).IsAssignableFrom(type)); foreach (var handlerType in requestHandlerTypes) { @@ -178,11 +170,6 @@ static string GetRequestHandlerMethod(Type handlerType, Type? requestType, Type } } - static bool IsTypeRequestHandler(Type type) - { - return type.GetInterfaces().Contains(typeof(IMethodHandler)); - } - static void VerifyHandlers(IEnumerable requestHandlerKeys) { var missingMethods = requestHandlerKeys.Where(meta => RequiredMethods.All(method => method == meta.MethodName)); @@ -202,59 +189,51 @@ private record HandlerTypes(Type? RequestType, Type? ResponseType, Type RequestC /// private static List ConvertHandlerTypeToRequestResponseTypes(Type handlerType) { - var genericInterfaces = handlerType.GetInterfaces().Where(i => i.IsGenericType); - var requestHandlerGenericTypes = GetGenericTypes(genericInterfaces, typeof(IRequestHandler<,,>)); - var parameterlessNotificationHandlerGenericTypes = GetGenericTypes(genericInterfaces, typeof(INotificationHandler<>)); - var notificationHandlerGenericTypes = GetGenericTypes(genericInterfaces, typeof(INotificationHandler<,>)); - var handlerList = new List(); - foreach (var requestHandlerGenericType in requestHandlerGenericTypes) + foreach (var interfaceType in handlerType.GetInterfaces()) { - var genericArguments = requestHandlerGenericType.GetGenericArguments(); - - if (genericArguments.Length != 3) + if (!interfaceType.IsGenericType) { - throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not have exactly three generic arguments"); + continue; } - handlerList.Add(new HandlerTypes(RequestType: genericArguments[0], ResponseType: genericArguments[1], RequestContext: genericArguments[2])); - } + var genericDefinition = interfaceType.GetGenericTypeDefinition(); - foreach (var parameterlessNotificationHandlerGenericType in parameterlessNotificationHandlerGenericTypes) - { - var genericArguments = parameterlessNotificationHandlerGenericType.GetGenericArguments(); - - if (genericArguments.Length != 1) + HandlerTypes types; + if (genericDefinition == typeof(IRequestHandler<,,>)) { - throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not have exactly 1 generic argument"); + var genericArguments = interfaceType.GetGenericArguments(); + types = new HandlerTypes(RequestType: genericArguments[0], ResponseType: genericArguments[1], RequestContext: genericArguments[2]); } - - handlerList.Add(new HandlerTypes(RequestType: null, ResponseType: null, RequestContext: genericArguments[0])); - } - - foreach (var notificationHandlerGenericType in notificationHandlerGenericTypes) - { - var genericArguments = notificationHandlerGenericType.GetGenericArguments(); - - if (genericArguments.Length != 2) + else if (genericDefinition == typeof(IRequestHandler<,>)) { - throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not have exactly 2 generic arguments"); + var genericArguments = interfaceType.GetGenericArguments(); + types = new HandlerTypes(RequestType: null, ResponseType: genericArguments[0], RequestContext: genericArguments[1]); + } + else if (genericDefinition == typeof(INotificationHandler<,>)) + { + var genericArguments = interfaceType.GetGenericArguments(); + types = new HandlerTypes(RequestType: genericArguments[0], ResponseType: null, RequestContext: genericArguments[1]); + } + else if (genericDefinition == typeof(INotificationHandler<>)) + { + var genericArguments = interfaceType.GetGenericArguments(); + types = new HandlerTypes(RequestType: null, ResponseType: null, RequestContext: genericArguments[0]); + } + else + { + continue; } - handlerList.Add(new HandlerTypes(RequestType: genericArguments[0], ResponseType: null, RequestContext: genericArguments[1])); + handlerList.Add(types); } - if (!handlerList.Any()) + if (handlerList.Count == 0) { - throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not implement {typeof(IRequestHandler<,,>).Name}, {typeof(INotificationHandler<>).Name} or {typeof(INotificationHandler<,>).Name}"); + throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not implement {nameof(IMethodHandler)}"); } return handlerList; - - static IEnumerable GetGenericTypes(IEnumerable genericInterfaces, Type methodHandlerType) - { - return genericInterfaces.Where(i => i.GetGenericTypeDefinition() == methodHandlerType); - } } } diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs index 1a7049af95c7e..8dcb9e1f04a41 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs @@ -11,11 +11,11 @@ namespace Microsoft.CommonLanguageServerProtocol.Framework; /// -/// A placeholder type to help handle Notification messages. +/// A placeholder type to help handle parameterless messages and messages with no return value. /// -internal record VoidReturn +internal sealed class NoValue { - public static VoidReturn Instance = new(); + public static NoValue Instance = new(); } internal class QueueItem : IQueueItem @@ -119,32 +119,38 @@ public async Task StartRequestAsync(TRequestContext? context, CancellationToken _logger.LogWarning($"Could not get request context for {MethodName}"); _completionSource.TrySetException(new InvalidOperationException($"Unable to create request context for {MethodName}")); } + else if (_handler is IRequestHandler requestHandler) + { + var result = await requestHandler.HandleRequestAsync(_request, context, cancellationToken).ConfigureAwait(false); + + _completionSource.TrySetResult(result); + } + else if (_handler is IRequestHandler parameterlessRequestHandler) + { + var result = await parameterlessRequestHandler.HandleRequestAsync(context, cancellationToken).ConfigureAwait(false); + + _completionSource.TrySetResult(result); + } + else if (_handler is INotificationHandler notificationHandler) + { + await notificationHandler.HandleNotificationAsync(_request, context, cancellationToken).ConfigureAwait(false); + + // We know that the return type of will always be even if the compiler doesn't. + _completionSource.TrySetResult((TResponse)(object)NoValue.Instance); + } + else if (_handler is INotificationHandler parameterlessNotificationHandler) + { + await parameterlessNotificationHandler.HandleNotificationAsync(context, cancellationToken).ConfigureAwait(false); + + // We know that the return type of will always be even if the compiler doesn't. + _completionSource.TrySetResult((TResponse)(object)NoValue.Instance); + } else { - if (_handler is IRequestHandler requestHandler) - { - var result = await requestHandler.HandleRequestAsync(_request, context, cancellationToken).ConfigureAwait(false); - - _completionSource.TrySetResult(result); - } - else if (_handler is INotificationHandler notificationHandler) - { - await notificationHandler.HandleNotificationAsync(_request, context, cancellationToken).ConfigureAwait(false); - - // We know that the return type of will always be even if the compiler doesn't. - _completionSource.TrySetResult((TResponse)(object)VoidReturn.Instance); - } - else if (_handler is INotificationHandler parameterlessNotificationHandler) - { - await parameterlessNotificationHandler.HandleNotificationAsync(context, cancellationToken).ConfigureAwait(false); - - // We know that the return type of will always be even if the compiler doesn't. - _completionSource.TrySetResult((TResponse)(object)VoidReturn.Instance); - } - else - { - throw new NotImplementedException($"Unrecognized {nameof(IMethodHandler)} implementation {_handler.GetType().Name}"); - } + throw new NotImplementedException( + $"Unrecognized {nameof(IMethodHandler)} implementation {_handler.GetType()}. " + + $"TRequest is {typeof(TRequest)}. " + + $"TResponse is {typeof(TResponse)}."); } } catch (OperationCanceledException ex) diff --git a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs index d37e859ee6f86..17d76febb3871 100644 --- a/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs +++ b/src/Features/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs @@ -82,8 +82,8 @@ public void Start() protected IMethodHandler GetMethodHandler(string methodName) { - var requestType = typeof(TRequest) == typeof(VoidReturn) ? null : typeof(TRequest); - var responseType = typeof(TResponse) == typeof(VoidReturn) ? null : typeof(TResponse); + var requestType = typeof(TRequest) == typeof(NoValue) ? null : typeof(TRequest); + var responseType = typeof(TResponse) == typeof(NoValue) ? null : typeof(TResponse); var handler = _handlerProvider.GetMethodHandler(methodName, requestType, responseType);