Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix up implementation for parameterless requests. #68077

Merged
merged 2 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,6 @@ protected virtual IHandlerProvider GetHandlerProvider()

protected virtual void SetupRequestDispatcher(IHandlerProvider handlerProvider)
{
var entryPointMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(DelegatingEntryPoint.EntryPointAsync));
if (entryPointMethod is null)
throw new InvalidOperationException($"{typeof(DelegatingEntryPoint).FullName} is missing method {nameof(DelegatingEntryPoint.EntryPointAsync)}");
var notificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(DelegatingEntryPoint.NotificationEntryPointAsync));
if (notificationMethod is null)
throw new InvalidOperationException($"{typeof(DelegatingEntryPoint).FullName} is missing method {nameof(DelegatingEntryPoint.NotificationEntryPointAsync)}");

var parameterlessNotificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(DelegatingEntryPoint.ParameterlessNotificationEntryPointAsync));
if (parameterlessNotificationMethod is null)
throw new InvalidOperationException($"{typeof(DelegatingEntryPoint).FullName} is missing method {nameof(DelegatingEntryPoint.ParameterlessNotificationEntryPointAsync)}");

foreach (var metadata in handlerProvider.GetRegisteredMethods())
{
// Instead of concretely defining methods for each LSP method, we instead dynamically construct the
Expand All @@ -113,31 +102,17 @@ protected virtual void SetupRequestDispatcher(IHandlerProvider handlerProvider)
//
// We also do not use the StreamJsonRpc support for JToken as the rpc method parameters because we want
// StreamJsonRpc to do the deserialization to handle streaming requests using IProgress<T>.

var method = DelegatingEntryPoint.GetMethodInstantiation(metadata.RequestType, metadata.ResponseType);

var delegatingEntryPoint = new DelegatingEntryPoint(metadata.MethodName, this);

MethodInfo genericEntryPointMethod;
if (metadata.RequestType is not null && metadata.ResponseType is not null)
{
genericEntryPointMethod = entryPointMethod.MakeGenericMethod(metadata.RequestType, metadata.ResponseType);
}
else if (metadata.RequestType is not null && metadata.ResponseType is null)
{
genericEntryPointMethod = notificationMethod.MakeGenericMethod(metadata.RequestType);
}
else if (metadata.RequestType is null && metadata.ResponseType is null)
{
// No need to genericize
genericEntryPointMethod = parameterlessNotificationMethod;
}
else
{
throw new NotImplementedException($"An unrecognized {nameof(RequestHandlerMetadata)} situation has occured");
}
var methodAttribute = new JsonRpcMethodAttribute(metadata.MethodName)
{
UseSingleObjectParameterDeserialization = true,
};
_jsonRpc.AddLocalRpcMethod(genericEntryPointMethod, delegatingEntryPoint, methodAttribute);

_jsonRpc.AddLocalRpcMethod(method, delegatingEntryPoint, methodAttribute);
}
}

Expand Down Expand Up @@ -171,31 +146,45 @@ protected IRequestExecutionQueue<TRequestContext> GetRequestExecutionQueue()
/// Wrapper class to hold the method and properties from the <see cref="AbstractLanguageServer{RequestContextType}"/>
/// that the method info passed to StreamJsonRpc is created from.
/// </summary>
private class DelegatingEntryPoint
private sealed class DelegatingEntryPoint
{
private readonly string _method;
private readonly AbstractLanguageServer<TRequestContext> _target;

private static readonly MethodInfo s_entryPointMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(EntryPointAsync));
private static readonly MethodInfo s_parameterlessEntryPointMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(ParameterlessEntryPointAsync));
private static readonly MethodInfo s_notificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(NotificationEntryPointAsync));
private static readonly MethodInfo s_parameterlessNotificationMethod = typeof(DelegatingEntryPoint).GetMethod(nameof(ParameterlessNotificationEntryPointAsync));

public DelegatingEntryPoint(string method, AbstractLanguageServer<TRequestContext> target)
{
_method = method;
_target = target;
}

public static MethodInfo GetMethodInstantiation(Type? requestType, Type? responseType)
=> (requestType, responseType) switch
{
(requestType: not null, responseType: not null) => s_entryPointMethod.MakeGenericMethod(requestType, responseType),
(requestType: null, responseType: not null) => s_parameterlessEntryPointMethod.MakeGenericMethod(responseType),
(requestType: not null, responseType: null) => s_notificationMethod.MakeGenericMethod(requestType),
(requestType: null, responseType: null) => s_parameterlessNotificationMethod,
};

public async Task NotificationEntryPointAsync<TRequest>(TRequest request, CancellationToken cancellationToken) where TRequest : class
{
var queue = _target.GetRequestExecutionQueue();
var lspServices = _target.GetLspServices();

_ = await queue.ExecuteAsync<TRequest, VoidReturn>(request, _method, lspServices, cancellationToken).ConfigureAwait(false);
_ = await queue.ExecuteAsync<TRequest, NoValue>(request, _method, lspServices, cancellationToken).ConfigureAwait(false);
}

public async Task ParameterlessNotificationEntryPointAsync(CancellationToken cancellationToken)
{
var queue = _target.GetRequestExecutionQueue();
var lspServices = _target.GetLspServices();

_ = await queue.ExecuteAsync<VoidReturn, VoidReturn>(VoidReturn.Instance, _method, lspServices, cancellationToken).ConfigureAwait(false);
_ = await queue.ExecuteAsync<NoValue, NoValue>(NoValue.Instance, _method, lspServices, cancellationToken).ConfigureAwait(false);
}

public async Task<TResponse?> EntryPointAsync<TRequest, TResponse>(TRequest request, CancellationToken cancellationToken) where TRequest : class
Expand All @@ -207,6 +196,16 @@ public async Task ParameterlessNotificationEntryPointAsync(CancellationToken can

return result;
}

public async Task<TResponse?> ParameterlessEntryPointAsync<TResponse>(CancellationToken cancellationToken)
{
var queue = _target.GetRequestExecutionQueue();
var lspServices = _target.GetLspServices();

var result = await queue.ExecuteAsync<NoValue, TResponse>(NoValue.Instance, _method, lspServices, cancellationToken).ConfigureAwait(false);

return result;
}
}

public Task WaitForExitAsync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.Contracts;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Threading;
Expand Down Expand Up @@ -202,59 +202,51 @@ private record HandlerTypes(Type? RequestType, Type? ResponseType, Type RequestC
/// </summary>
private static List<HandlerTypes> ConvertHandlerTypeToRequestResponseTypes(Type handlerType)
{
var genericInterfaces = handlerType.GetInterfaces().Where(i => i.IsGenericType);
var requestHandlerGenericTypes = GetGenericTypes(genericInterfaces, typeof(IRequestHandler<,,>));
var parameterlessNotificationHandlerGenericTypes = GetGenericTypes(genericInterfaces, typeof(INotificationHandler<>));
var notificationHandlerGenericTypes = GetGenericTypes(genericInterfaces, typeof(INotificationHandler<,>));

var handlerList = new List<HandlerTypes>();

foreach (var requestHandlerGenericType in requestHandlerGenericTypes)
foreach (var interfaceType in handlerType.GetInterfaces())
{
var genericArguments = requestHandlerGenericType.GetGenericArguments();

if (genericArguments.Length != 3)
if (!interfaceType.IsGenericType)
{
throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not have exactly three generic arguments");
continue;
}

handlerList.Add(new HandlerTypes(RequestType: genericArguments[0], ResponseType: genericArguments[1], RequestContext: genericArguments[2]));
}

foreach (var parameterlessNotificationHandlerGenericType in parameterlessNotificationHandlerGenericTypes)
{
var genericArguments = parameterlessNotificationHandlerGenericType.GetGenericArguments();
var genericDefinition = interfaceType.GetGenericTypeDefinition();

if (genericArguments.Length != 1)
HandlerTypes types;
if (genericDefinition == typeof(IRequestHandler<,,>))
{
throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not have exactly 1 generic argument");
var genericArguments = interfaceType.GetGenericArguments();
types = new HandlerTypes(RequestType: genericArguments[0], ResponseType: genericArguments[1], RequestContext: genericArguments[2]);
}

handlerList.Add(new HandlerTypes(RequestType: null, ResponseType: null, RequestContext: genericArguments[0]));
}

foreach (var notificationHandlerGenericType in notificationHandlerGenericTypes)
{
var genericArguments = notificationHandlerGenericType.GetGenericArguments();

if (genericArguments.Length != 2)
else if (genericDefinition == typeof(IRequestHandler<,>))
{
throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not have exactly 2 generic arguments");
var genericArguments = interfaceType.GetGenericArguments();
types = new HandlerTypes(RequestType: null, ResponseType: genericArguments[0], RequestContext: genericArguments[1]);
}
else if (genericDefinition == typeof(INotificationHandler<,>))
{
var genericArguments = interfaceType.GetGenericArguments();
types = new HandlerTypes(RequestType: genericArguments[0], ResponseType: null, RequestContext: genericArguments[1]);
}
else if (genericDefinition == typeof(INotificationHandler<>))
{
var genericArguments = interfaceType.GetGenericArguments();
types = new HandlerTypes(RequestType: null, ResponseType: null, RequestContext: genericArguments[0]);
}
else
{
continue;
}

handlerList.Add(new HandlerTypes(RequestType: genericArguments[0], ResponseType: null, RequestContext: genericArguments[1]));
handlerList.Add(types);
}

if (!handlerList.Any())
if (handlerList.Count == 0)
{
throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not implement {typeof(IRequestHandler<,,>).Name}, {typeof(INotificationHandler<>).Name} or {typeof(INotificationHandler<,>).Name}");
throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not implement {nameof(IMethodHandler)}");
}

return handlerList;

static IEnumerable<Type> GetGenericTypes(IEnumerable<Type> genericInterfaces, Type methodHandlerType)
{
return genericInterfaces.Where(i => i.GetGenericTypeDefinition() == methodHandlerType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
namespace Microsoft.CommonLanguageServerProtocol.Framework;

/// <summary>
/// A placeholder type to help handle Notification messages.
/// A placeholder type to help handle parameterless messages and messages with no return value.
/// </summary>
internal record VoidReturn
internal sealed class NoValue
{
public static VoidReturn Instance = new();
public static NoValue Instance = new();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find anyone else using this, so this might not need a dual insertion

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's internal, so shouldn't be used, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe everyone using this has IVTs

}

internal class QueueItem<TRequest, TResponse, TRequestContext> : IQueueItem<TRequestContext>
Expand Down Expand Up @@ -119,32 +119,38 @@ public async Task StartRequestAsync(TRequestContext? context, CancellationToken
_logger.LogWarning($"Could not get request context for {MethodName}");
_completionSource.TrySetException(new InvalidOperationException($"Unable to create request context for {MethodName}"));
}
else if (_handler is IRequestHandler<TRequest, TResponse, TRequestContext> requestHandler)
{
var result = await requestHandler.HandleRequestAsync(_request, context, cancellationToken).ConfigureAwait(false);

_completionSource.TrySetResult(result);
}
else if (_handler is IRequestHandler<TResponse, TRequestContext> parameterlessRequestHandler)
{
var result = await parameterlessRequestHandler.HandleRequestAsync(context, cancellationToken).ConfigureAwait(false);

_completionSource.TrySetResult(result);
}
else if (_handler is INotificationHandler<TRequest, TRequestContext> notificationHandler)
{
await notificationHandler.HandleNotificationAsync(_request, context, cancellationToken).ConfigureAwait(false);

// We know that the return type of <see cref="INotificationHandler{TRequestType, RequestContextType}"/> will always be <see cref="VoidReturn" /> even if the compiler doesn't.
_completionSource.TrySetResult((TResponse)(object)NoValue.Instance);
}
else if (_handler is INotificationHandler<TRequestContext> parameterlessNotificationHandler)
{
await parameterlessNotificationHandler.HandleNotificationAsync(context, cancellationToken).ConfigureAwait(false);

// We know that the return type of <see cref="INotificationHandler{TRequestType, RequestContextType}"/> will always be <see cref="VoidReturn" /> even if the compiler doesn't.
_completionSource.TrySetResult((TResponse)(object)NoValue.Instance);
}
else
{
if (_handler is IRequestHandler<TRequest, TResponse, TRequestContext> requestHandler)
{
var result = await requestHandler.HandleRequestAsync(_request, context, cancellationToken).ConfigureAwait(false);

_completionSource.TrySetResult(result);
}
else if (_handler is INotificationHandler<TRequest, TRequestContext> notificationHandler)
{
await notificationHandler.HandleNotificationAsync(_request, context, cancellationToken).ConfigureAwait(false);

// We know that the return type of <see cref="INotificationHandler{TRequestType, RequestContextType}"/> will always be <see cref="VoidReturn" /> even if the compiler doesn't.
_completionSource.TrySetResult((TResponse)(object)VoidReturn.Instance);
}
else if (_handler is INotificationHandler<TRequestContext> parameterlessNotificationHandler)
{
await parameterlessNotificationHandler.HandleNotificationAsync(context, cancellationToken).ConfigureAwait(false);

// We know that the return type of <see cref="INotificationHandler{TRequestType, RequestContextType}"/> will always be <see cref="VoidReturn" /> even if the compiler doesn't.
_completionSource.TrySetResult((TResponse)(object)VoidReturn.Instance);
}
else
{
throw new NotImplementedException($"Unrecognized {nameof(IMethodHandler)} implementation {_handler.GetType().Name}");
}
throw new NotImplementedException(
$"Unrecognized {nameof(IMethodHandler)} implementation {_handler.GetType()}. " +
$"TRequest is {typeof(TRequest)}. " +
$"TResponse is {typeof(TResponse)}.");
}
}
catch (OperationCanceledException ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ public void Start()

protected IMethodHandler GetMethodHandler<TRequest, TResponse>(string methodName)
{
var requestType = typeof(TRequest) == typeof(VoidReturn) ? null : typeof(TRequest);
var responseType = typeof(TResponse) == typeof(VoidReturn) ? null : typeof(TResponse);
var requestType = typeof(TRequest) == typeof(NoValue) ? null : typeof(TRequest);
var responseType = typeof(TResponse) == typeof(NoValue) ? null : typeof(TResponse);

var handler = _handlerProvider.GetMethodHandler(methodName, requestType, responseType);

Expand Down