Skip to content

Commit

Permalink
implement apiversion parameter (#4527)
Browse files Browse the repository at this point in the history
Fix #4206

- When the service is versioned, the client will has apiVersion field. 
- Each operation which has apiVersion parameter should use the client
apiVersion parameter directly and the operation signature should not
include the apiVersion parameter.
  • Loading branch information
chunyu3 authored Oct 10, 2024
1 parent 9121ed4 commit bd95294
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public ClientOptionsProvider(InputClient inputClient, ClientProvider clientProvi
}
}

internal PropertyProvider? VersionProperty => _versionProperty;
private TypeProvider? ServiceVersionEnum => _serviceVersionEnum?.Value;
private FieldProvider? LatestVersionField => _latestVersionField ??= BuildLatestVersionField();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ public class ClientProvider : TypeProvider
private readonly FieldProvider? _apiKeyAuthField;
private readonly FieldProvider? _authorizationHeaderConstant;
private readonly FieldProvider? _authorizationApiKeyPrefixConstant;
private FieldProvider? _apiVersionField;
private readonly ParameterProvider[] _subClientInternalConstructorParams;
private IReadOnlyList<Lazy<ClientProvider>>? _subClients;
private ParameterProvider? _clientOptionsParameter;
private ClientOptionsProvider? _clientOptions;
private RestClientProvider? _restClient;
private readonly InputParameter[] _allClientParameters;

private ParameterProvider? ClientOptionsParameter => _clientOptionsParameter ??= ClientOptions != null
? ScmKnownParameters.ClientOptions(ClientOptions.Type)
Expand Down Expand Up @@ -105,6 +107,8 @@ public ClientProvider(InputClient inputClient)
}

_endpointParameterName = new(GetEndpointParameterName);

_allClientParameters = _inputClient.Parameters.Concat(_inputClient.Operations.SelectMany(op => op.Parameters).Where(p => p.Kind == InputOperationParameterKind.Client)).DistinctBy(p => p.Name).ToArray();
}

private List<ParameterProvider>? _uriParameters;
Expand Down Expand Up @@ -160,19 +164,23 @@ protected override FieldProvider[] BuildFields()
}
}

// Add optional client parameters as fields
foreach (var p in _inputClient.Parameters)
foreach (var p in _allClientParameters)
{
if (!p.IsEndpoint)
{
var type = ClientModelPlugin.Instance.TypeFactory.CreateCSharpType(p.Type);
if (type != null)
{
fields.Add(new(
FieldProvider field = new(
FieldModifiers.Private | FieldModifiers.ReadOnly,
type,
"_" + p.Name.ToVariableName(),
this));
this);
if (p.IsApiVersion)
{
_apiVersionField = field;
}
fields.Add(field);
}
}
}
Expand Down Expand Up @@ -241,9 +249,9 @@ private IReadOnlyList<ParameterProvider> GetRequiredParameters()
_uriParameters = [];

ParameterProvider? currentParam = null;
foreach (var parameter in _inputClient.Parameters)
foreach (var parameter in _allClientParameters)
{
if (parameter.IsRequired && !parameter.IsEndpoint)
if (parameter.IsRequired && !parameter.IsEndpoint && !parameter.IsApiVersion)
{
currentParam = ClientModelPlugin.Instance.TypeFactory.CreateParameter(parameter);
currentParam.Field = Fields.FirstOrDefault(f => f.Name == "_" + parameter.Name);
Expand Down Expand Up @@ -304,10 +312,16 @@ private MethodBodyStatement[] BuildPrimaryConstructorBody(IReadOnlyList<Paramete
{
if (f != _apiKeyAuthField
&& f != EndpointField
&& !f.Modifiers.HasFlag(FieldModifiers.Const)
&& clientOptionsPropertyDict.TryGetValue(f.Name.ToCleanName(), out var optionsProperty))
&& !f.Modifiers.HasFlag(FieldModifiers.Const))
{
body.Add(f.Assign(ClientOptionsParameter.Property(optionsProperty.Name)).Terminate());
if (f == _apiVersionField && ClientOptions.VersionProperty != null)
{
body.Add(f.Assign(ClientOptionsParameter.Property(ClientOptions.VersionProperty.Name)).Terminate());
}
else if (clientOptionsPropertyDict.TryGetValue(f.Name.ToCleanName(), out var optionsProperty))
{
clientOptionsPropertyDict.TryGetValue(f.Name.ToCleanName(), out optionsProperty);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,23 @@ private MethodProvider BuildCreateRequestMethod(InputOperation operation)
null,
[.. GetMethodParameters(operation, true), options]);
var paramMap = new Dictionary<string, ParameterProvider>(signature.Parameters.ToDictionary(p => p.Name));

foreach (var param in ClientProvider.GetUriParameters())
{
paramMap[param.Name] = param;
}

/* add client-level parameter.*/
foreach (var inputParam in operation.Parameters)
{
if (inputParam.Kind == InputOperationParameterKind.Client && !paramMap.ContainsKey(inputParam.Name))
{
var param = ClientModelPlugin.Instance.TypeFactory.CreateParameter(inputParam);
param.Field = ClientProvider.Fields.FirstOrDefault(f => f.Name == "_" + inputParam.Name);
paramMap[inputParam.Name] = param;
}
}

var classifier = GetClassifier(operation);

return new MethodProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,26 @@ public void ValidateClientWithSpread(InputClient inputClient)

}

[Test]
public void TestApiVersionOfClient()
{
var client = InputFactory.Client(TestClientName,
operations: [
InputFactory.Operation("OperationWithApiVersion",
parameters: [InputFactory.Parameter("apiVersion", InputPrimitiveType.String, isRequired: true, location: RequestLocation.Query, kind: InputOperationParameterKind.Client)])
]);
var clientProvider = new ClientProvider(client);
Assert.IsNotNull(clientProvider);

/* verify that the client has apiVersion field */
Assert.IsNotNull(clientProvider.Fields.FirstOrDefault(f => f.Name.Equals("_apiVersion")));

var method = clientProvider.Methods.FirstOrDefault(m => m.Signature.Name.Equals("OperationWithApiVersion"));
Assert.IsNotNull(method);
/* verify that the method does not have apiVersion parameter */
Assert.IsNull(method?.Signature.Parameters.FirstOrDefault(p => p.Name.Equals("apiVersion")));
}

private static InputClient GetEnumQueryParamClient()
=> InputFactory.Client(
TestClientName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.Generator.CSharp.Tests.Common;
using NUnit.Framework;
using Microsoft.Generator.CSharp.Snippets;
using Microsoft.Generator.CSharp.Statements;

namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers.ClientProviders
{
Expand Down Expand Up @@ -195,6 +196,26 @@ public void ValidateClientWithSpecialHeaders()
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

[Test]
public void ValidateClientWithApiVersion()
{
var client = InputFactory.Client("TestClient",
operations: [
InputFactory.Operation("OperationWithApiVersion",
parameters: [InputFactory.Parameter("apiVersion", InputPrimitiveType.String, isRequired: true, location: RequestLocation.Query, kind: InputOperationParameterKind.Client)])
]);
var clientProvider = new ClientProvider(client);
var restClientProvider = new MockClientProvider(client, clientProvider);
var method = restClientProvider.Methods.FirstOrDefault(m => m.Signature.Name == "CreateOperationWithApiVersionRequest");
Assert.IsNotNull(method);
/* verify that there is no apiVersion parameter in method signature. */
Assert.IsNull(method?.Signature.Parameters.FirstOrDefault(p => p.Name.Equals("apiVersion")));
var bodyStatements = method?.BodyStatements as MethodBodyStatements;
Assert.IsNotNull(bodyStatements);
/* verify that it will use client _apiVersion field to append query parameter. */
Assert.IsTrue(bodyStatements!.Statements.Any(s => s.ToDisplayString() == "uri.AppendQuery(\"apiVersion\", _apiVersion, true);\n"));
}

private readonly static InputOperation BasicOperation = InputFactory.Operation(
"CreateMessage",
parameters:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,9 @@ op stillConvenient(): void;
@head
@convenientAPI(true)
op headAsBoolean(@path id: string): void;

@route("/WithApiVersion")
@doc("Return hi again")
@get
@convenientAPI(true)
op WithApiVersion(@header p1: string, @query apiVersion: string): void;
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,22 @@ internal PipelineMessage CreateHeadAsBooleanRequest(string id, RequestOptions op
return message;
}

internal PipelineMessage CreateWithApiVersionRequest(string p1, RequestOptions options)
{
PipelineMessage message = Pipeline.CreateMessage();
message.ResponseClassifier = PipelineMessageClassifier204;
PipelineRequest request = message.Request;
request.Method = "GET";
ClientUriBuilder uri = new ClientUriBuilder();
uri.Reset(_endpoint);
uri.AppendPath("/WithApiVersion", false);
uri.AppendQuery("apiVersion", _apiVersion, true);
request.Uri = uri.ToUri();
request.Headers.Set("p1", p1);
message.Apply(options);
return message;
}

private class Classifier2xxAnd4xx : PipelineMessageClassifier
{
public override bool TryClassify(PipelineMessage message, out bool isError)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public partial class UnbrandedTypeSpecClient
private const string AuthorizationHeader = "my-api-key";
/// <summary> A credential used to authenticate to the service. </summary>
private readonly ApiKeyCredential _keyCredential;
private readonly string _apiVersion;

/// <summary> Initializes a new instance of UnbrandedTypeSpecClient for mocking. </summary>
protected UnbrandedTypeSpecClient()
Expand Down Expand Up @@ -48,6 +49,7 @@ public UnbrandedTypeSpecClient(Uri endpoint, ApiKeyCredential keyCredential, Unb
_endpoint = endpoint;
_keyCredential = keyCredential;
Pipeline = ClientPipeline.Create(options, Array.Empty<PipelinePolicy>(), new PipelinePolicy[] { ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy(_keyCredential, AuthorizationHeader) }, Array.Empty<PipelinePolicy>());
_apiVersion = options.Version;
}

/// <summary> The HTTP pipeline for sending and receiving REST requests and responses. </summary>
Expand Down Expand Up @@ -1160,5 +1162,69 @@ public virtual async Task<ClientResult> HeadAsBooleanAsync(string id)

return await HeadAsBooleanAsync(id, null).ConfigureAwait(false);
}

/// <summary>
/// [Protocol Method] Return hi again
/// <list type="bullet">
/// <item>
/// <description> This <see href="https://aka.ms/azsdk/net/protocol-methods">protocol method</see> allows explicit creation of the request and processing of the response for advanced scenarios. </description>
/// </item>
/// </list>
/// </summary>
/// <param name="p1"></param>
/// <param name="options"> The request options, which can override default behaviors of the client pipeline on a per-call basis. </param>
/// <exception cref="ArgumentNullException"> <paramref name="p1"/> is null. </exception>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
/// <returns> The response returned from the service. </returns>
public virtual ClientResult WithApiVersion(string p1, RequestOptions options)
{
Argument.AssertNotNull(p1, nameof(p1));

using PipelineMessage message = CreateWithApiVersionRequest(p1, options);
return ClientResult.FromResponse(Pipeline.ProcessMessage(message, options));
}

/// <summary>
/// [Protocol Method] Return hi again
/// <list type="bullet">
/// <item>
/// <description> This <see href="https://aka.ms/azsdk/net/protocol-methods">protocol method</see> allows explicit creation of the request and processing of the response for advanced scenarios. </description>
/// </item>
/// </list>
/// </summary>
/// <param name="p1"></param>
/// <param name="options"> The request options, which can override default behaviors of the client pipeline on a per-call basis. </param>
/// <exception cref="ArgumentNullException"> <paramref name="p1"/> is null. </exception>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
/// <returns> The response returned from the service. </returns>
public virtual async Task<ClientResult> WithApiVersionAsync(string p1, RequestOptions options)
{
Argument.AssertNotNull(p1, nameof(p1));

using PipelineMessage message = CreateWithApiVersionRequest(p1, options);
return ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false));
}

/// <summary> Return hi again. </summary>
/// <param name="p1"></param>
/// <exception cref="ArgumentNullException"> <paramref name="p1"/> is null. </exception>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
public virtual ClientResult WithApiVersion(string p1)
{
Argument.AssertNotNull(p1, nameof(p1));

return WithApiVersion(p1, null);
}

/// <summary> Return hi again. </summary>
/// <param name="p1"></param>
/// <exception cref="ArgumentNullException"> <paramref name="p1"/> is null. </exception>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
public virtual async Task<ClientResult> WithApiVersionAsync(string p1)
{
Argument.AssertNotNull(p1, nameof(p1));

return await WithApiVersionAsync(p1, null).ConfigureAwait(false);
}
}
}
Loading

0 comments on commit bd95294

Please sign in to comment.