diff --git a/sdk/core/Azure.Core/tests/ManagementPipelineBuilderTests.cs b/sdk/core/Azure.Core/tests/ManagementPipelineBuilderTests.cs new file mode 100644 index 0000000000000..421367feda572 --- /dev/null +++ b/sdk/core/Azure.Core/tests/ManagementPipelineBuilderTests.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Linq; +using System.Reflection; +using System.Threading; +using Azure.Core.Pipeline; +using Azure.Core.TestFramework; +using Azure.ResourceManager.Core; +using NUnit.Framework; + +namespace Azure.Core.Tests.Management +{ + public class ManagementPipelineBuilderTests + { + [TestCase] + public void AddPerCallPolicy() + { + var options = new ArmClientOptions(); + var dummyPolicy = new DummyPolicy(); + options.AddPolicy(dummyPolicy, HttpPipelinePosition.PerCall); + var pipeline = ManagementPipelineBuilder.Build(new MockCredential(), new Uri("http://foo.com"), options); + + var perCallPolicyField = pipeline.GetType().GetField("_pipeline", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.GetField); + var policies = (ReadOnlyMemory)perCallPolicyField.GetValue(pipeline); + Assert.IsNotNull(policies.ToArray().FirstOrDefault(p => p.GetType() == typeof(DummyPolicy))); + } + + [TestCase] + public void AddPerCallPolicyViaClient() + { + var options = new ArmClientOptions(); + var dummyPolicy = new DummyPolicy(); + options.AddPolicy(dummyPolicy, HttpPipelinePosition.PerCall); + var client = new ArmClient(Guid.NewGuid().ToString(), new MockCredential(), options); + + var pipelineProperty = client.GetType().GetProperty("Pipeline", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.GetProperty); + var pipeline = pipelineProperty.GetValue(client) as HttpPipeline; + + var perCallPolicyField = pipeline.GetType().GetField("_pipeline", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.GetField); + var policies = (ReadOnlyMemory)perCallPolicyField.GetValue(pipeline); + Assert.IsNotNull(policies.ToArray().FirstOrDefault(p => p.GetType() == typeof(DummyPolicy))); + } + + private class DummyPolicy : HttpPipelineSynchronousPolicy + { + public int numMsgGot = 0; + + public override void OnReceivedResponse(HttpMessage message) + { + Interlocked.Increment(ref numMsgGot); + } + } + } +} diff --git a/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClient.cs b/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClient.cs index da14d5e265174..07433568676d4 100644 --- a/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClient.cs +++ b/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClient.cs @@ -97,7 +97,7 @@ private ArmClient( DefaultSubscription = string.IsNullOrWhiteSpace(defaultSubscriptionId) ? GetDefaultSubscription() : GetSubscriptions().TryGet(defaultSubscriptionId); - ClientOptions.ApiVersions.SetProviderClient(credential, baseUri, DefaultSubscription.Id.SubscriptionId); + ClientOptions.ApiVersions.SetProviderClient(credential, baseUri, defaultSubscriptionId ?? DefaultSubscription.Id.SubscriptionId); } /// diff --git a/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClientOptions.cs b/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClientOptions.cs index 76d72ad712a1d..62b1ae41464d2 100644 --- a/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClientOptions.cs +++ b/sdk/resourcemanager/Azure.ResourceManager.Core/src/ArmClientOptions.cs @@ -1,15 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using Azure.Core; -using Azure.Core.Pipeline; using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.ComponentModel; using System.Reflection; -using System.Linq; -using System.Runtime.CompilerServices; +using Azure.Core; +using Azure.Core.Pipeline; namespace Azure.ResourceManager.Core { @@ -29,7 +27,7 @@ public sealed class ArmClientOptions : ClientOptions /// Initializes a new instance of the class. /// public ArmClientOptions() - : this(LocationData.Default, null) + : this(LocationData.Default) { } @@ -38,64 +36,19 @@ public ArmClientOptions() /// /// The default location to use if can't be inherited from parent. public ArmClientOptions(LocationData defaultLocation) - : this(defaultLocation, null) - { - } - - /// - /// Initializes a new instance of the class. - /// - /// The default location to use if can't be inherited from parent. - /// The client parameters to use in these operations. - /// If is null. - internal ArmClientOptions(LocationData defaultLocation, ArmClientOptions other = null) { if (defaultLocation is null) throw new ArgumentNullException(nameof(defaultLocation)); - // Will go away when moved into core since we will have directly access the policies and transport, so just need to set those - if (!ReferenceEquals(other, null)) - Copy(other); DefaultLocation = defaultLocation; ApiVersions = new ApiVersions(this); } - private ArmClientOptions(LocationData location, IList perCallPolicies, IList perRetryPolicies) - { - if (location is null) - throw new ArgumentNullException(nameof(location)); - - DefaultLocation = location; - PerCallPolicies = new List(); - foreach (var call in perCallPolicies) - { - PerCallPolicies.Add(call); - } - PerRetryPolicies = new List(); - foreach (var retry in perRetryPolicies) - { - PerCallPolicies.Add(retry); - } - ApiVersions = new ApiVersions(this); - } - /// /// Gets the default location to use if can't be inherited from parent. /// public LocationData DefaultLocation { get; } - /// - /// Gets each http call policies. - /// - /// A collection of http pipeline policy that may take multiple service requests to iterate over. - internal IList PerCallPolicies { get; } = new List(); - - /// - /// Gets each http retry call policies. - /// - /// A collection of http pipeline policy that may take multiple service requests to iterate over. - internal IList PerRetryPolicies { get; } = new List(); - /// /// Converts client options. /// @@ -106,47 +59,12 @@ public T Convert() { var newOptions = new T(); newOptions.Transport = Transport; - foreach (var pol in PerCallPolicies) - { - newOptions.AddPolicy(pol, HttpPipelinePosition.PerCall); - } - foreach (var pol in PerRetryPolicies) - { - newOptions.AddPolicy(pol, HttpPipelinePosition.PerRetry); - } + CopyPolicies(this, newOptions); return newOptions; } - /// - /// Adds a policy for Azure resource manager client http call. - /// - /// The http call policy in the pipeline. - /// The position of the http call policy in the pipeline. - /// If is null. - public new void AddPolicy(HttpPipelinePolicy policy, HttpPipelinePosition position) - { - if (policy is null) - throw new ArgumentNullException(nameof(policy)); - - // TODO policy lists are internal hence we don't have access to them by inheriting ClientOptions in this Assembly, this is a wrapper for now to convert to the concrete - // policy options. - switch (position) - { - case HttpPipelinePosition.PerCall: - PerCallPolicies.Add(policy); - break; - case HttpPipelinePosition.PerRetry: - PerRetryPolicies.Add(policy); - break; - default: - throw new ArgumentOutOfRangeException(nameof(position), position, null); - } - - base.AddPolicy(policy, position); - } - /// /// Gets override object. /// @@ -162,24 +80,31 @@ public object GetOverrideObject(Func objectConstructor) return _overrides.GetOrAdd(typeof(T), objectConstructor()); } - // Will be removed like AddPolicy when we move to azure core - private void Copy(ArmClientOptions other) + private static void CopyPolicies(ClientOptions source, ClientOptions dest) { - Transport = other.Transport; - foreach (var pol in other.PerCallPolicies) + var perCallPoliciesProperty = source.GetType().GetProperty("PerCallPolicies", BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.GetProperty); + var perCallPolicies = perCallPoliciesProperty.GetValue(source) as IList; + + foreach (var policy in perCallPolicies) { - AddPolicy(pol, HttpPipelinePosition.PerCall); + dest.AddPolicy(policy, HttpPipelinePosition.PerCall); } - foreach (var pol in other.PerRetryPolicies) + var perRetryPoliciesProperty = source.GetType().GetProperty("PerRetryPolicies", BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.GetProperty); + var perRetryPolicies = perRetryPoliciesProperty.GetValue(source) as IList; + + foreach (var policy in perRetryPolicies) { - AddPolicy(pol, HttpPipelinePosition.PerRetry); + dest.AddPolicy(policy, HttpPipelinePosition.PerRetry); } } internal ArmClientOptions Clone() { - ArmClientOptions copy = new ArmClientOptions(DefaultLocation, PerCallPolicies, PerRetryPolicies); + ArmClientOptions copy = new ArmClientOptions(DefaultLocation); + + CopyPolicies(this, copy); + copy.ApiVersions = ApiVersions.Clone(); copy.Transport = Transport; return copy; diff --git a/sdk/resourcemanager/Azure.ResourceManager.Core/tests/Unit/ArmClientOptionsTests.cs b/sdk/resourcemanager/Azure.ResourceManager.Core/tests/Unit/ArmClientOptionsTests.cs index edcc23fe9a6bd..f038add357d7f 100644 --- a/sdk/resourcemanager/Azure.ResourceManager.Core/tests/Unit/ArmClientOptionsTests.cs +++ b/sdk/resourcemanager/Azure.ResourceManager.Core/tests/Unit/ArmClientOptionsTests.cs @@ -21,8 +21,6 @@ public void ValidateClone() Assert.IsFalse(ReferenceEquals(options1.Diagnostics, options2.Diagnostics)); Assert.IsFalse(ReferenceEquals(options1.Retry, options2.Retry)); Assert.IsFalse(ReferenceEquals(options1.ApiVersions, options2.ApiVersions)); - Assert.IsFalse(ReferenceEquals(options1.PerCallPolicies, options2.PerCallPolicies)); - Assert.IsFalse(ReferenceEquals(options1.PerRetryPolicies, options2.PerRetryPolicies)); } [TestCase] @@ -77,10 +75,6 @@ public void MultiClientSeparateVersions() public void TestClientOptionsParamCheck() { Assert.Throws(() => { new ArmClientOptions(null); }); - Assert.Throws(() => { new ArmClientOptions(null, null); }); - - var options = new ArmClientOptions(); - Assert.Throws(() => { options.AddPolicy(null, HttpPipelinePosition.PerCall); }); } [TestCase]