Skip to content

Commit

Permalink
fix: Move HeaderInterceptor to use our Middleware API
Browse files Browse the repository at this point in the history
Prior to this commit, we had a gRPC Interceptor that was
responsible for setting headers (cache name, auth token, etc.)
that were sent to the server.

We attempted to implement our Middleware layer using gRPC interceptors
but because their API is not asynchronous, there was no way to
accomplish the asynchronous tasks that we needed to accomplish,
so we had to introduce a bespoke layer for our Middleware that uses
the same patterns and some of the same constructs as the gRPC
Interceptors do, but which supported an asynchronous API.

This left us with a mix of both Interceptors (for the Header Interceptor)
and Middlewares.  This proved to be problematic when attempting to
implement request retries, because we could not control the order that
the interceptor ran in relative to the middlewares.  Thus whenever
we were executing our retry logic, the header interceptor would be
re-run and it would add a second copy of all of the headers, which
made the request invalid.

It's important that we be able to exercise full control over the
order that all of the middlewares are executed, so in this PR
we migrate the Header Interceptor over to be a Middleware instead.
This allows us to position it in the correct spot in the ordered
list of Middlewares.

Because the control plane client was also using the Headers interceptor,
this commit also adds Middleware support for the control plane client
and migrates it from the old Headers Interceptor to the new Headers
Middleware.
  • Loading branch information
cprice404 committed Oct 7, 2022
1 parent 9214bb4 commit 963378f
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 96 deletions.
89 changes: 86 additions & 3 deletions src/Momento.Sdk/Internal/ControlGrpcManager.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,95 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Grpc.Net.Client;
using Microsoft.Extensions.Logging;
using Momento.Protos.CacheClient;
using Momento.Protos.ControlClient;
using Momento.Sdk.Config.Middleware;
using Momento.Sdk.Internal;
using Momento.Sdk.Internal.Middleware;
using static System.Reflection.Assembly;

namespace Momento.Sdk.Internal;


public interface IControlClient
{
public Task<_CreateCacheResponse> CreateCacheAsync(_CreateCacheRequest request, CallOptions callOptions);
public Task<_DeleteCacheResponse> DeleteCacheAsync(_DeleteCacheRequest request, CallOptions callOptions);
public Task<_ListCachesResponse> ListCachesAsync(_ListCachesRequest request, CallOptions callOptions);
}


// Ideally we would implement our middleware based on gRPC Interceptors. Unfortunately,
// the their method signatures are not asynchronous. Thus, for any middleware that may
// require asynchronous actions (such as our MaxConcurrentRequestsMiddleware), we would
// end up blocking threads to wait for the completion of the async task, which would have
// a big negative impact on performance. Instead, in this commit, we implement a thin
// middleware layer of our own that uses asynchronous signatures throughout. This has
// the nice side effect of making the user-facing API for writing Middlewares a bit less
// of a learning curve for anyone not super deep on gRPC internals.
internal class ControlClientWithMiddleware : IControlClient
{
private readonly IList<IMiddleware> _middlewares;
private readonly ScsControl.ScsControlClient _generatedClient;

public ControlClientWithMiddleware(ScsControl.ScsControlClient generatedClient, IList<IMiddleware> middlewares)
{
_generatedClient = generatedClient;
_middlewares = middlewares;
}

public async Task<_CreateCacheResponse> CreateCacheAsync(_CreateCacheRequest request, CallOptions callOptions)
{
var wrapped = await WrapWithMiddleware(request, callOptions, (r, o) => _generatedClient.CreateCacheAsync(r, o));
return await wrapped.ResponseAsync;
}

public async Task<_DeleteCacheResponse> DeleteCacheAsync(_DeleteCacheRequest request, CallOptions callOptions)
{
var wrapped = await WrapWithMiddleware(request, callOptions, (r, o) => _generatedClient.DeleteCacheAsync(r, o));
return await wrapped.ResponseAsync;
}

public async Task<_ListCachesResponse> ListCachesAsync(_ListCachesRequest request, CallOptions callOptions)
{
var wrapped = await WrapWithMiddleware(request, callOptions, (r, o) => _generatedClient.ListCachesAsync(r, o));
return await wrapped.ResponseAsync;
}

private async Task<MiddlewareResponseState<TResponse>> WrapWithMiddleware<TRequest, TResponse>(
TRequest request,
CallOptions callOptions,
Func<TRequest, CallOptions, AsyncUnaryCall<TResponse>> continuation
)
{
Func<TRequest, CallOptions, Task<MiddlewareResponseState<TResponse>>> continuationWithMiddlewareResponseState = (r, o) =>
{
var result = continuation(r, o);
return Task.FromResult(new MiddlewareResponseState<TResponse>(
ResponseAsync: result.ResponseAsync,
ResponseHeadersAsync: result.ResponseHeadersAsync,
GetStatus: result.GetStatus,
GetTrailers: result.GetTrailers
));
};

var wrapped = _middlewares.Aggregate(continuationWithMiddlewareResponseState, (acc, middleware) =>
{
return (r, o) => middleware.WrapRequest(r, o, acc);
});
return await wrapped(request, callOptions);
}
}

internal sealed class ControlGrpcManager : IDisposable
{
private readonly GrpcChannel channel;
public ScsControl.ScsControlClient Client { get; }
public IControlClient Client { get; }
private readonly string version = "dotnet:" + GetAssembly(typeof(Momento.Sdk.Responses.CacheGetResponse)).GetName().Version.ToString();
// Some System.Environment.Version remarks to be aware of
// https://learn.microsoft.com/en-us/dotnet/api/system.environment.version?view=netstandard-2.0#remarks
Expand All @@ -25,8 +101,15 @@ public ControlGrpcManager(string authToken, string endpoint, ILoggerFactory logg
var uri = $"https://{endpoint}";
this.channel = GrpcChannel.ForAddress(uri, new GrpcChannelOptions() { Credentials = ChannelCredentials.SecureSsl });
List<Header> headers = new List<Header> { new Header(name: Header.AuthorizationKey, value: authToken), new Header(name: Header.AgentKey, value: version), new Header(name: Header.RuntimeVersionKey, value: runtimeVersion) };
CallInvoker invoker = channel.Intercept(new HeaderInterceptor(headers));
Client = new ScsControl.ScsControlClient(invoker);
CallInvoker invoker = this.channel.CreateCallInvoker();

var middlewares = new List<IMiddleware> {
new HeaderMiddleware(headers)
};


Client = new ControlClientWithMiddleware(new ScsControl.ScsControlClient(invoker), middlewares);

this._logger = loggerFactory.CreateLogger<ControlGrpcManager>();
}

Expand Down
7 changes: 5 additions & 2 deletions src/Momento.Sdk/Internal/DataGrpcManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,15 @@ internal DataGrpcManager(IConfiguration config, string authToken, string host)

this.channel = GrpcChannel.ForAddress(url, channelOptions);
List<Header> headers = new List<Header> { new Header(name: Header.AuthorizationKey, value: authToken), new Header(name: Header.AgentKey, value: version), new Header(name: Header.RuntimeVersionKey, value: runtimeVersion) };

this._logger = config.LoggerFactory.CreateLogger<DataGrpcManager>();
CallInvoker invoker = this.channel.Intercept(new HeaderInterceptor(headers));

CallInvoker invoker = this.channel.CreateCallInvoker();

var middlewares = config.Middlewares.Concat(
new List<IMiddleware> {
new MaxConcurrentRequestsMiddleware(config.LoggerFactory, config.TransportStrategy.MaxConcurrentRequests)
new MaxConcurrentRequestsMiddleware(config.TransportStrategy.MaxConcurrentRequests),
new HeaderMiddleware(headers)
}
).ToList();

Expand Down
88 changes: 0 additions & 88 deletions src/Momento.Sdk/Internal/HeaderInterceptor.cs

This file was deleted.

75 changes: 75 additions & 0 deletions src/Momento.Sdk/Internal/Middleware/HeaderMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Momento.Sdk.Config.Middleware;

namespace Momento.Sdk.Internal.Middleware
{
class Header
{
public const string AuthorizationKey = "Authorization";
public const string AgentKey = "Agent";
public const string RuntimeVersionKey = "Runtime_Version";
public readonly List<string> onceOnlyHeaders = new List<string> { Header.AgentKey, Header.RuntimeVersionKey };
public string Name;
public string Value;
public Header(String name, String value)
{
this.Name = name;
this.Value = value;
}
}

class HeaderMiddleware : IMiddleware
{
private readonly List<Header> headersToAddEveryTime = new List<Header> { };
private readonly List<Header> headersToAddOnce = new List<Header> { };
private volatile Boolean areOnlyOnceHeadersSent = false;
public HeaderMiddleware(List<Header> headers)
{
this.headersToAddOnce = headers.Where(header => header.onceOnlyHeaders.Contains(header.Name)).ToList();
this.headersToAddEveryTime = headers.Where(header => !header.onceOnlyHeaders.Contains(header.Name)).ToList();
}

public async Task<MiddlewareResponseState<TResponse>> WrapRequest<TRequest, TResponse>(
TRequest request,
CallOptions callOptions,
Func<TRequest, CallOptions, Task<MiddlewareResponseState<TResponse>>> continuation
)
{
var callOptionsWithHeaders = callOptions;
if (callOptionsWithHeaders.Headers == null)
{
callOptionsWithHeaders = callOptionsWithHeaders.WithHeaders(new Metadata());
}

var headers = callOptionsWithHeaders.Headers!;

foreach (Header header in this.headersToAddEveryTime)
{
headers.Add(header.Name, header.Value);
}
if (!areOnlyOnceHeadersSent)
{
foreach (Header header in this.headersToAddOnce)
{
headers.Add(header.Name, header.Value);
}
areOnlyOnceHeadersSent = true;
}


var nextState = await continuation(request, callOptionsWithHeaders);
return new MiddlewareResponseState<TResponse>(
ResponseAsync: nextState.ResponseAsync,
ResponseHeadersAsync: nextState.ResponseHeadersAsync,
GetStatus: nextState.GetStatus,
GetTrailers: nextState.GetTrailers
);
}
}
}

6 changes: 3 additions & 3 deletions src/Momento.Sdk/Internal/ScsControlClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public async Task<CreateCacheResponse> CreateCacheAsync(string cacheName)
{
CheckValidCacheName(cacheName);
_CreateCacheRequest request = new _CreateCacheRequest() { CacheName = cacheName };
await this.grpcManager.Client.CreateCacheAsync(request, deadline: CalculateDeadline());
await this.grpcManager.Client.CreateCacheAsync(request, new CallOptions(deadline: CalculateDeadline()));
return new CreateCacheResponse.Success();
}
catch (Exception e)
Expand All @@ -49,7 +49,7 @@ public async Task<DeleteCacheResponse> DeleteCacheAsync(string cacheName)
{
CheckValidCacheName(cacheName);
_DeleteCacheRequest request = new _DeleteCacheRequest() { CacheName = cacheName };
await this.grpcManager.Client.DeleteCacheAsync(request, deadline: CalculateDeadline());
await this.grpcManager.Client.DeleteCacheAsync(request, new CallOptions(deadline: CalculateDeadline()));
return new DeleteCacheResponse.Success();
}
catch (Exception e)
Expand All @@ -63,7 +63,7 @@ public async Task<ListCachesResponse> ListCachesAsync(string? nextPageToken = nu
_ListCachesRequest request = new _ListCachesRequest() { NextToken = nextPageToken == null ? "" : nextPageToken };
try
{
_ListCachesResponse result = await this.grpcManager.Client.ListCachesAsync(request, deadline: CalculateDeadline());
_ListCachesResponse result = await this.grpcManager.Client.ListCachesAsync(request, new CallOptions(deadline: CalculateDeadline()));
return new ListCachesResponse.Success(result);
}
catch (Exception e)
Expand Down

0 comments on commit 963378f

Please sign in to comment.