Skip to content

Commit

Permalink
Fix up implementation for parameterless requests. (#68077)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmat authored May 4, 2023
1 parent f6d9c05 commit f731f97
Show file tree
Hide file tree
Showing 14 changed files with 441 additions and 397 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,8 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CommonLanguageServerProtocol.Framework;
using Microsoft.VisualStudio.LanguageServer.Protocol;
using Nerdbank.Streams;
using StreamJsonRpc;
using Xunit;

namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,120 +4,97 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CommonLanguageServerProtocol.Framework;
using System.Linq;
using Xunit;

namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests;

public partial class HandlerProviderTests
public class HandlerProviderTests
{
private const string _method = "SomeMethod";
private const string _wrongMethod = "WrongMethod";
private static readonly Type _requestType = typeof(int);
private static readonly Type _responseType = typeof(string);
private static readonly Type _wrongResponseType = typeof(long);
[Theory]
[CombinatorialData]
public void GetMethodHandler(bool supportsGetRegisteredServices)
{
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices);

private static readonly IMethodHandler _expectedMethodHandler = new TestMethodHandler();
var methodHander = handlerProvider.GetMethodHandler(TestMethodHandler.Name, TestMethodHandler.RequestType, TestMethodHandler.ResponseType);
Assert.Same(TestMethodHandler.Instance, methodHander);
}

[Fact]
public void GetMethodHandler_ViaGetRequiredServices_Succeeds()
[Theory]
[CombinatorialData]
public void GetMethodHandler_Parameterless(bool supportsGetRegisteredServices)
{
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false);
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices);

var methodHander = handlerProvider.GetMethodHandler(_method, _requestType, _responseType);

Assert.Same(_expectedMethodHandler, methodHander);
var methodHander = handlerProvider.GetMethodHandler(TestParameterlessMethodHandler.Name, requestType: null, TestParameterlessMethodHandler.ResponseType);
Assert.Same(TestParameterlessMethodHandler.Instance, methodHander);
}

[Fact]
public void GetMethodHandler_ViaGetRegisteredServices_Succeeds()
[Theory]
[CombinatorialData]
public void GetMethodHandler_Notification(bool supportsGetRegisteredServices)
{
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: true);
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices);

var methodHander = handlerProvider.GetMethodHandler(_method, _requestType, _responseType);

Assert.Same(_expectedMethodHandler, methodHander);
var methodHander = handlerProvider.GetMethodHandler(TestNotificationHandler.Name, TestNotificationHandler.RequestType, responseType: null);
Assert.Same(TestNotificationHandler.Instance, methodHander);
}

[Fact]
public void GetMethodHandler_WrongMethod_Throws()
[Theory]
[CombinatorialData]
public void GetMethodHandler_ParameterlessNotification(bool supportsGetRegisteredServices)
{
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false);
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices);

Assert.Throws<InvalidOperationException>(() => handlerProvider.GetMethodHandler(_wrongMethod, _requestType, _responseType));
var methodHander = handlerProvider.GetMethodHandler(TestParameterlessNotificationHandler.Name, requestType: null, responseType: null);
Assert.Same(TestParameterlessNotificationHandler.Instance, methodHander);
}

[Fact]
public void GetMethodHandler_WrongResponseType_Throws()
public void GetMethodHandler_WrongMethod_Throws()
{
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false);

Assert.Throws<InvalidOperationException>(() => handlerProvider.GetMethodHandler(_method, _requestType, _wrongResponseType));
Assert.Throws<InvalidOperationException>(() => handlerProvider.GetMethodHandler("UndefinedMethod", TestMethodHandler.RequestType, TestMethodHandler.ResponseType));
}

[Fact]
public void GetRegisteredMethods_GetRequiredServices()
public void GetMethodHandler_WrongResponseType_Throws()
{
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: false);

var registeredMethods = handlerProvider.GetRegisteredMethods();

Assert.Collection(registeredMethods,
(r) => Assert.Equal(_method, r.MethodName));
Assert.Throws<InvalidOperationException>(() => handlerProvider.GetMethodHandler(TestMethodHandler.Name, TestMethodHandler.RequestType, responseType: typeof(long)));
}

[Fact]
public void GetRegisteredMethods_GetRegisteredServices()
[Theory]
[CombinatorialData]
public void GetRegisteredMethods(bool supportsGetRegisteredServices)
{
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices: true);
var handlerProvider = GetHandlerProvider(supportsGetRegisteredServices);

var registeredMethods = handlerProvider.GetRegisteredMethods();
var registeredMethods = handlerProvider.GetRegisteredMethods().OrderBy(m => m.MethodName);

Assert.Collection(registeredMethods,
(r) => Assert.Equal(_method, r.MethodName));
r => Assert.Equal(TestMethodHandler.Name, r.MethodName),
r => Assert.Equal(TestNotificationHandler.Name, r.MethodName),
r => Assert.Equal(TestParameterlessMethodHandler.Name, r.MethodName),
r => Assert.Equal(TestParameterlessNotificationHandler.Name, r.MethodName));
}

private static HandlerProvider GetHandlerProvider(bool supportsGetRegisteredServices)
{
var lspServices = GetLspServices(supportsGetRegisteredServices);
var handler = new HandlerProvider(lspServices);

return handler;
}
=> new(GetLspServices(supportsGetRegisteredServices));

private static ILspServices GetLspServices(bool supportsGetRegisteredServices)
private static TestLspServices GetLspServices(bool supportsGetRegisteredServices)
{
var services = new List<(Type, object)> { (typeof(IMethodHandler), _expectedMethodHandler) };
var lspServices = new TestLspServices(services, supportsGetRegisteredServices);
return lspServices;
}

[LanguageServerEndpoint(_method)]
internal class TestMethodHandler : IRequestHandler<int, string, TestRequestContext>
{
public bool MutatesSolutionState => true;

public static string Method = _method;

public static Type RequestType = typeof(int);

public static Type ResponseType = typeof(string);

public Task<string> HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken)
var services = new List<(Type, object)>
{
return Task.FromResult("stuff");
}
}
(typeof(IMethodHandler), TestMethodHandler.Instance),
(typeof(IMethodHandler), TestNotificationHandler.Instance),
(typeof(IMethodHandler), TestParameterlessMethodHandler.Instance),
(typeof(IMethodHandler), TestParameterlessNotificationHandler.Instance),
};

private class TestMethodHandlerWithoutAttribute : INotificationHandler<TestRequestContext>
{
public bool MutatesSolutionState => true;

public Task HandleNotificationAsync(TestRequestContext requestContext, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
return new TestLspServices(services, supportsGetRegisteredServices);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CommonLanguageServerProtocol.Framework;

namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests;

public class NoOpLspLogger : ILspLogger
{
public static NoOpLspLogger Instance = new NoOpLspLogger();
public static NoOpLspLogger Instance = new();

public void LogError(string message, params object[] @params)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;

namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests;

internal class TestHandlerProvider : IHandlerProvider
{
private readonly IEnumerable<(RequestHandlerMetadata metadata, IMethodHandler provider)> _providers;

public TestHandlerProvider(IEnumerable<(RequestHandlerMetadata metadata, IMethodHandler provider)> providers)
=> _providers = providers;

public IMethodHandler GetMethodHandler(string method, Type? requestType, Type? responseType)
=> _providers.Single(p => p.metadata.MethodName == method).provider;

public ImmutableArray<RequestHandlerMetadata> GetRegisteredMethods()
=> _providers.Select(p => p.metadata).ToImmutableArray();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;

namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests;

internal class TestLspServices : ILspServices
{
private readonly bool _supportsGetRegisteredServices;
private readonly IEnumerable<(Type type, object instance)> _services;

public TestLspServices(IEnumerable<(Type type, object instance)> services, bool supportsGetRegisteredServices)
{
_services = services;
_supportsGetRegisteredServices = supportsGetRegisteredServices;
}

public void Dispose()
{
}

public ImmutableArray<Type> GetRegisteredServices()
=> _services.Select(s => s.instance.GetType()).ToImmutableArray();

public T GetRequiredService<T>() where T : notnull
=> (T?)TryGetService(typeof(T)) ?? throw new InvalidOperationException($"{typeof(T).Name} did not have a service");

public IEnumerable<T> GetRequiredServices<T>()
=> _supportsGetRegisteredServices ? Array.Empty<T>() : _services.Where(s => s.instance is T).Select(s => (T)s.instance);

public bool SupportsGetRegisteredServices()
=> _supportsGetRegisteredServices;

public object? TryGetService(Type type)
=> _services.FirstOrDefault(s => (_supportsGetRegisteredServices ? s.instance.GetType() : s.type) == type).instance;
}
Loading

0 comments on commit f731f97

Please sign in to comment.