diff --git a/Microsoft.Azure.Cosmos.Encryption/src/AssemblyInfo.cs b/Microsoft.Azure.Cosmos.Encryption/src/AssemblyInfo.cs new file mode 100644 index 0000000000..d2a8aae49f --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/AssemblyInfo.cs @@ -0,0 +1,8 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.Azure.Cosmos.Encryption.Tests" + Microsoft.Azure.Cosmos.Encryption.AssemblyKeys.TestPublicKey)] +[assembly: InternalsVisibleTo("Microsoft.Azure.Cosmos.Encryption.EmulatorTests" + Microsoft.Azure.Cosmos.Encryption.AssemblyKeys.TestPublicKey)] \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/AssemblyKeys.cs b/Microsoft.Azure.Cosmos.Encryption/src/AssemblyKeys.cs deleted file mode 100644 index 94e4ce2ed4..0000000000 --- a/Microsoft.Azure.Cosmos.Encryption/src/AssemblyKeys.cs +++ /dev/null @@ -1,14 +0,0 @@ -//------------------------------------------------------------ -// Copyright (c) Microsoft Corporation. All rights reserved. -//------------------------------------------------------------ - -using System.Runtime.CompilerServices; - -[assembly: InternalsVisibleTo("Microsoft.Azure.Cosmos.Encryption.Tests" + AssemblyKeys.TestPublicKey)] -[assembly: InternalsVisibleTo("Microsoft.Azure.Cosmos.Encryption.EmulatorTests" + AssemblyKeys.TestPublicKey)] - -internal static class AssemblyKeys -{ - /// TestPublicKey is an unsupported strong key for testing and internal use only - internal const string TestPublicKey = ", PublicKey=0024000004800000940000000602000000240000525341310004000001000100197c25d0a04f73cb271e8181dba1c0c713df8deebb25864541a66670500f34896d280484b45fe1ff6c29f2ee7aa175d8bcbd0c83cc23901a894a86996030f6292ce6eda6e6f3e6c74b3c5a3ded4903c951e6747e6102969503360f7781bf8bf015058eb89b7621798ccc85aaca036ff1bc1556bb7f62de15908484886aa8bbae"; -} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/DecryptableFeedResponse.cs b/Microsoft.Azure.Cosmos.Encryption/src/DecryptableFeedResponse.cs index 849e50b67a..d47d29530d 100644 --- a/Microsoft.Azure.Cosmos.Encryption/src/DecryptableFeedResponse.cs +++ b/Microsoft.Azure.Cosmos.Encryption/src/DecryptableFeedResponse.cs @@ -49,11 +49,7 @@ internal static DecryptableFeedResponse CreateResponse( using (responseMessage) { - // ReadFeed can return 304 in some scenarios (for example Change Feed) - if (responseMessage.StatusCode != HttpStatusCode.NotModified) - { - responseMessage.EnsureSuccessStatusCode(); - } + responseMessage.EnsureSuccessStatusCode(); return new DecryptableFeedResponse( responseMessage, diff --git a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs index 958bcd3bc6..9f004623c5 100644 --- a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs +++ b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs @@ -800,75 +800,6 @@ public override Task> GetFeedRangesAsync( return this.container.GetFeedRangesAsync(cancellationToken); } - public override FeedIterator GetChangeFeedStreamIterator( - string continuationToken = null, - ChangeFeedRequestOptions changeFeedRequestOptions = null) - { - return new EncryptionFeedIterator( - this.container.GetChangeFeedStreamIterator( - continuationToken, - changeFeedRequestOptions), - this.Encryptor, - this.CosmosSerializer); - } - - public override FeedIterator GetChangeFeedStreamIterator( - FeedRange feedRange, - ChangeFeedRequestOptions changeFeedRequestOptions = null) - { - return new EncryptionFeedIterator( - this.container.GetChangeFeedStreamIterator( - feedRange, - changeFeedRequestOptions), - this.Encryptor, - this.CosmosSerializer); - } - - public override FeedIterator GetChangeFeedStreamIterator( - PartitionKey partitionKey, - ChangeFeedRequestOptions changeFeedRequestOptions = null) - { - return new EncryptionFeedIterator( - this.container.GetChangeFeedStreamIterator( - partitionKey, - changeFeedRequestOptions), - this.Encryptor, - this.CosmosSerializer); - } - - public override FeedIterator GetChangeFeedIterator( - string continuationToken = null, - ChangeFeedRequestOptions changeFeedRequestOptions = null) - { - return new EncryptionFeedIterator( - (EncryptionFeedIterator)this.GetChangeFeedStreamIterator( - continuationToken, - changeFeedRequestOptions), - this.ResponseFactory); - } - - public override FeedIterator GetChangeFeedIterator( - FeedRange feedRange, - ChangeFeedRequestOptions changeFeedRequestOptions = null) - { - return new EncryptionFeedIterator( - (EncryptionFeedIterator)this.GetChangeFeedStreamIterator( - feedRange, - changeFeedRequestOptions), - this.ResponseFactory); - } - - public override FeedIterator GetChangeFeedIterator( - PartitionKey partitionKey, - ChangeFeedRequestOptions changeFeedRequestOptions = null) - { - return new EncryptionFeedIterator( - (EncryptionFeedIterator)this.GetChangeFeedStreamIterator( - partitionKey, - changeFeedRequestOptions), - this.ResponseFactory); - } - public override Task> GetPartitionKeyRangesAsync( FeedRange feedRange, CancellationToken cancellationToken = default) @@ -906,5 +837,35 @@ public override FeedIterator GetItemQueryIterator( requestOptions), this.ResponseFactory); } + + public override ChangeFeedEstimator GetChangeFeedEstimator( + string processorName, + Container leaseContainer) + { + return this.container.GetChangeFeedEstimator(processorName, leaseContainer); + } + + public override FeedIterator GetChangeFeedStreamIterator( + ChangeFeedStartFrom changeFeedStartFrom, + ChangeFeedRequestOptions changeFeedRequestOptions = null) + { + return new EncryptionFeedIterator( + this.container.GetChangeFeedStreamIterator( + changeFeedStartFrom, + changeFeedRequestOptions), + this.Encryptor, + this.CosmosSerializer); + } + + public override FeedIterator GetChangeFeedIterator( + ChangeFeedStartFrom changeFeedStartFrom, + ChangeFeedRequestOptions changeFeedRequestOptions = null) + { + return new EncryptionFeedIterator( + (EncryptionFeedIterator)this.GetChangeFeedStreamIterator( + changeFeedStartFrom, + changeFeedRequestOptions), + this.ResponseFactory); + } } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatch.cs b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatch.cs index a166a8ecd6..1f2d8c8e75 100644 --- a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatch.cs +++ b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatch.cs @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Cosmos.Encryption { using System; + using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -199,26 +200,60 @@ public override async Task ExecuteAsync( using (diagnosticsContext.CreateScope("TransactionalBatch.ExecuteAsync")) { TransactionalBatchResponse response = await this.transactionalBatch.ExecuteAsync(cancellationToken); + return await this.DecryptTransactionalBatchResponseAsync( + response, + diagnosticsContext, + cancellationToken); + } + } + + public override async Task ExecuteAsync( + TransactionalBatchRequestOptions requestOptions, + CancellationToken cancellationToken = default) + { + CosmosDiagnosticsContext diagnosticsContext = CosmosDiagnosticsContext.Create(options: null); + using (diagnosticsContext.CreateScope("TransactionalBatch.ExecuteAsync.WithRequestOptions")) + { + TransactionalBatchResponse response = await this.transactionalBatch.ExecuteAsync(requestOptions, cancellationToken); + return await this.DecryptTransactionalBatchResponseAsync( + response, + diagnosticsContext, + cancellationToken); + } + } + + private async Task DecryptTransactionalBatchResponseAsync( + TransactionalBatchResponse response, + CosmosDiagnosticsContext diagnosticsContext, + CancellationToken cancellationToken) + { + List decryptedTransactionalBatchOperationResults = new List(); - if (response.IsSuccessStatusCode) + if (response.IsSuccessStatusCode) + { + for (int index = 0; index < response.Count; index++) { - for (int index = 0; index < response.Count; index++) + TransactionalBatchOperationResult result = response[index]; + + if (result.ResourceStream != null) { - TransactionalBatchOperationResult result = response[index]; - - if (result.ResourceStream != null) - { - (result.ResourceStream, _) = await EncryptionProcessor.DecryptAsync( - result.ResourceStream, - this.encryptor, - diagnosticsContext, - cancellationToken); - } + (Stream decryptedStream, _) = await EncryptionProcessor.DecryptAsync( + result.ResourceStream, + this.encryptor, + diagnosticsContext, + cancellationToken); + + result = new EncryptionTransactionalBatchOperationResult(response[index], decryptedStream); } - } - return response; + decryptedTransactionalBatchOperationResults.Add(result); + } } + + return new EncryptionTransactionalBatchResponse( + decryptedTransactionalBatchOperationResults, + response, + this.cosmosSerializer); } } } diff --git a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchOperationResult.cs b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchOperationResult.cs new file mode 100644 index 0000000000..5a3e8ce2cc --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchOperationResult.cs @@ -0,0 +1,32 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.IO; + using System.Net; + + internal sealed class EncryptionTransactionalBatchOperationResult : TransactionalBatchOperationResult + { + private readonly Stream encryptionResourceStream; + private readonly TransactionalBatchOperationResult response; + + public EncryptionTransactionalBatchOperationResult(TransactionalBatchOperationResult response, Stream encryptionResourceStream) + { + this.response = response; + this.encryptionResourceStream = encryptionResourceStream; + } + + public override Stream ResourceStream => this.encryptionResourceStream; + + public override HttpStatusCode StatusCode => this.response.StatusCode; + + public override bool IsSuccessStatusCode => this.response.IsSuccessStatusCode; + + public override string ETag => this.response.ETag; + + public override TimeSpan RetryAfter => this.response.RetryAfter; + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchOperationResult{T}.cs b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchOperationResult{T}.cs new file mode 100644 index 0000000000..96b2f45081 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchOperationResult{T}.cs @@ -0,0 +1,24 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + internal sealed class EncryptionTransactionalBatchOperationResult : TransactionalBatchOperationResult + { + /// + /// Initializes a new instance of the class. + /// + /// BatchOperationResult with stream resource. + /// Deserialized resource. + internal EncryptionTransactionalBatchOperationResult(T resource) + { + this.Resource = resource; + } + + /// + /// Gets or sets the content of the resource. + /// + public override T Resource { get; set; } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchResponse.cs b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchResponse.cs new file mode 100644 index 0000000000..d7ee105dcc --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/EncryptionTransactionalBatchResponse.cs @@ -0,0 +1,79 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.Collections.Generic; + using System.Net; + + internal sealed class EncryptionTransactionalBatchResponse : TransactionalBatchResponse + { + private readonly IReadOnlyList results; + private readonly TransactionalBatchResponse response; + private readonly CosmosSerializer cosmosSerializer; + private bool isDisposed = false; + + public EncryptionTransactionalBatchResponse( + IReadOnlyList results, + TransactionalBatchResponse response, + CosmosSerializer cosmosSerializer) + { + this.results = results; + this.response = response; + this.cosmosSerializer = cosmosSerializer; + } + + public override TransactionalBatchOperationResult this[int index] => this.results[index]; + + public override TransactionalBatchOperationResult GetOperationResultAtIndex(int index) + { + TransactionalBatchOperationResult result = this.results[index]; + + T resource = default; + if (result.ResourceStream != null) + { + resource = this.cosmosSerializer.FromStream(result.ResourceStream); + } + + return new EncryptionTransactionalBatchOperationResult(resource); + } + + public override IEnumerator GetEnumerator() + { + return this.results.GetEnumerator(); + } + + public override Headers Headers => this.response.Headers; + + public override string ActivityId => this.response.ActivityId; + + public override double RequestCharge => this.response.RequestCharge; + + public override TimeSpan? RetryAfter => this.response.RetryAfter; + + public override HttpStatusCode StatusCode => this.response.StatusCode; + + public override string ErrorMessage => this.response.ErrorMessage; + + public override bool IsSuccessStatusCode => this.response.IsSuccessStatusCode; + + public override int Count => this.results?.Count ?? 0; + + public override CosmosDiagnostics Diagnostics => this.response.Diagnostics; + + protected override void Dispose(bool disposing) + { + if (disposing && !this.isDisposed) + { + this.isDisposed = true; + + if (this.response != null) + { + this.response.Dispose(); + } + } + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Microsoft.Azure.Cosmos.Encryption.csproj b/Microsoft.Azure.Cosmos.Encryption/src/Microsoft.Azure.Cosmos.Encryption.csproj index eaa28d35f8..84077d58c0 100644 --- a/Microsoft.Azure.Cosmos.Encryption/src/Microsoft.Azure.Cosmos.Encryption.csproj +++ b/Microsoft.Azure.Cosmos.Encryption/src/Microsoft.Azure.Cosmos.Encryption.csproj @@ -26,7 +26,7 @@ - + diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AssemblyKeys.cs b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AssemblyKeys.cs new file mode 100644 index 0000000000..c0297a0cd6 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AssemblyKeys.cs @@ -0,0 +1,12 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + internal static class AssemblyKeys + { + /// TestPublicKey is an unsupported strong key for testing and internal use only + internal const string TestPublicKey = ", PublicKey=0024000004800000940000000602000000240000525341310004000001000100197c25d0a04f73cb271e8181dba1c0c713df8deebb25864541a66670500f34896d280484b45fe1ff6c29f2ee7aa175d8bcbd0c83cc23901a894a86996030f6292ce6eda6e6f3e6c74b3c5a3ded4903c951e6747e6102969503360f7781bf8bf015058eb89b7621798ccc85aaca036ff1bc1556bb7f62de15908484886aa8bbae"; + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AsyncCache.cs b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AsyncCache.cs new file mode 100644 index 0000000000..87dd67b3c5 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AsyncCache.cs @@ -0,0 +1,239 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.Diagnostics; + using System.Threading; + using System.Threading.Tasks; + + /// + /// Cache which supports asynchronous value initialization. + /// It ensures that for given key only single inintialization funtion is running at any point in time. + /// + /// Type of keys. + /// Type of values. + [System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "VSTHRD105:Avoid method overloads that assume TaskScheduler.Current", Justification = "Mirrored file.")] + internal sealed class AsyncCache + { + private readonly IEqualityComparer valueEqualityComparer; + private readonly IEqualityComparer keyEqualityComparer; + + private ConcurrentDictionary> values; + + public AsyncCache(IEqualityComparer valueEqualityComparer, IEqualityComparer keyEqualityComparer = null) + { + this.keyEqualityComparer = keyEqualityComparer ?? EqualityComparer.Default; + this.values = new ConcurrentDictionary>(this.keyEqualityComparer); + this.valueEqualityComparer = valueEqualityComparer; + } + + public AsyncCache() + : this(EqualityComparer.Default) + { + } + + public ICollection Keys => this.values.Keys; + + public void Set(TKey key, TValue value) + { + AsyncLazy lazyValue = new AsyncLazy(value); + + // Access it to mark as created+completed, so that further calls to getasync do not overwrite. +#pragma warning disable VSTHRD002 // Avoid problematic synchronous waits + TValue x = lazyValue.Value.Result; +#pragma warning restore VSTHRD002 // Avoid problematic synchronous waits + + this.values.AddOrUpdate(key, lazyValue, (k, existingValue) => + { + // Observe all exceptions thrown for existingValue. + if (existingValue.IsValueCreated) + { + Task unused = existingValue.Value.ContinueWith(c => c.Exception, TaskContinuationOptions.OnlyOnFaulted); + } + + return lazyValue; + }); + } + + /// + /// + /// Gets value corresponding to . + /// + /// + /// If another initialization function is already running, new initialization function will not be started. + /// The result will be result of currently running initialization function. + /// + /// + /// If previous initialization function is successfully completed - value returned by it will be returned unless + /// it is equal to , in which case new initialization function will be started. + /// + /// + /// If previous initialization function failed - new one will be launched. + /// + /// + /// Key for which to get a value. + /// Value which is obsolete and needs to be refreshed. + /// Initialization function. + /// Cancellation token. + /// Skip cached value and generate new value. + /// Cached value or value returned by initialization function. + public async Task GetAsync( + TKey key, + TValue obsoleteValue, + Func> singleValueInitFunc, + CancellationToken cancellationToken, + bool forceRefresh = false) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (this.values.TryGetValue(key, out AsyncLazy initialLazyValue)) + { + // If we haven't computed the value or we're currently computing it, then return it... + if (!initialLazyValue.IsValueCreated || !initialLazyValue.Value.IsCompleted) + { + try + { + return await initialLazyValue.Value; + } + + // It does not matter to us if this instance of the task throws - the lambda that failed was provided by a different caller. + // The exception that we see here will be handled/logged by whatever caller provided the failing lambda, if any. Our part is catching and observing it. + // As such, we discard this exception and will retry with our own lambda below, for which we will let exception bubble up. + catch + { + } + } + + // Don't check Task if there's an exception or it's been canceled. Accessing Task.Exception marks it as observed, which we want. + else if (initialLazyValue.Value.Exception == null && !initialLazyValue.Value.IsCanceled) + { + TValue cachedValue = await initialLazyValue.Value; + + // If not forcing refresh or obsolete value, use cached value. + if (!forceRefresh && !this.valueEqualityComparer.Equals(cachedValue, obsoleteValue)) + { + return cachedValue; + } + } + } + + AsyncLazy newLazyValue = new AsyncLazy(singleValueInitFunc, cancellationToken); + + // Update the new task in the cache - compare-and-swap style. + AsyncLazy actualValue = this.values.AddOrUpdate( + key, + newLazyValue, + (existingKey, existingValue) => object.ReferenceEquals(existingValue, initialLazyValue) ? newLazyValue : existingValue); + + // Task starts running here. + Task generator = actualValue.Value; + + // Even if the current thread goes away, all exceptions will be observed. + Task unused = generator.ContinueWith(c => c.Exception, TaskContinuationOptions.OnlyOnFaulted); + + return await generator; + } + + public void Remove(TKey key) + { + if (this.values.TryRemove(key, out AsyncLazy initialLazyValue) && initialLazyValue.IsValueCreated) + { + // Observe all exceptions thrown. + Task unused = initialLazyValue.Value.ContinueWith(c => c.Exception, TaskContinuationOptions.OnlyOnFaulted); + } + } + + public bool TryRemoveIfCompleted(TKey key) + { + if (this.values.TryGetValue(key, out AsyncLazy initialLazyValue) && initialLazyValue.IsValueCreated && initialLazyValue.Value.IsCompleted) + { + // Accessing Exception marks as observed. + _ = initialLazyValue.Value.Exception; + + // This is a nice trick to do "atomic remove if value not changed". + // ConcurrentDictionary inherits from ICollection>, which allows removal of specific key value pair, instead of removal just by key. + ICollection>> valuesAsCollection = this.values as ICollection>>; + Debug.Assert(valuesAsCollection != null, "Values collection expected to implement ICollection>."); + return valuesAsCollection?.Remove(new KeyValuePair>(key, initialLazyValue)) ?? false; + } + + return false; + } + + /// + /// Remove value from cache and return it if present. + /// + /// Key + /// Value if present, default value if not present. + public async Task RemoveAsync(TKey key) + { + if (this.values.TryRemove(key, out AsyncLazy initialLazyValue)) + { + try + { + return await initialLazyValue.Value; + } + catch + { + } + } + + return default; + } + + public void Clear() + { + ConcurrentDictionary> newValues = new ConcurrentDictionary>(this.keyEqualityComparer); + ConcurrentDictionary> oldValues = Interlocked.Exchange(ref this.values, newValues); + + // Ensure all tasks are observed. + foreach (AsyncLazy value in oldValues.Values) + { + if (value.IsValueCreated) + { + Task unused = value.Value.ContinueWith(c => c.Exception, TaskContinuationOptions.OnlyOnFaulted); + } + } + + oldValues.Clear(); + } + + /// + /// Runs a background task that will started refreshing the cached value for a given key. + /// This observes the same logic as GetAsync - a running value will still take precedence over a call to this. + /// + /// Key. + /// Generator function. + public void BackgroundRefreshNonBlocking(TKey key, Func> singleValueInitFunc) + { + // Trigger background refresh of cached value. + // Fire and forget. + Task unused = Task.Factory.StartNewOnCurrentTaskSchedulerAsync(async () => + { + try + { + // If we don't have a value, or we have one that has completed running (i.e. if a value is currently being generated, we do nothing). + if (!this.values.TryGetValue(key, out AsyncLazy initialLazyValue) || (initialLazyValue.IsValueCreated && initialLazyValue.Value.IsCompleted)) + { + // Use GetAsync to trigger the generation of a value. + await this.GetAsync( + key, + default, // obsolete value unused since forceRefresh: true + singleValueInitFunc, + CancellationToken.None, + forceRefresh: true); + } + } + catch + { + // Observe all exceptions. + } + }).Unwrap(); + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AsyncLazy.cs b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AsyncLazy.cs new file mode 100644 index 0000000000..a2de41cf46 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/AsyncLazy.cs @@ -0,0 +1,28 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.Threading; + using System.Threading.Tasks; + + internal sealed class AsyncLazy : Lazy> + { + public AsyncLazy(T value) + : base(() => Task.FromResult(value)) + { + } + + public AsyncLazy(Func valueFactory, CancellationToken cancellationToken) + : base(() => Task.Factory.StartNewOnCurrentTaskSchedulerAsync(valueFactory, cancellationToken)) // Task.Factory.StartNew() allows specifying task scheduler to use which is critical for compute gateway to track physical consumption. + { + } + + public AsyncLazy(Func> taskFactory, CancellationToken cancellationToken) + : base(() => Task.Factory.StartNewOnCurrentTaskSchedulerAsync(taskFactory, cancellationToken).Unwrap()) // Task.Factory.StartNew() allows specifying task scheduler to use which is critical for compute gateway to track physical consumption. + { + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/BackoffRetryUtility.cs b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/BackoffRetryUtility.cs new file mode 100644 index 0000000000..3139f8dade --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/BackoffRetryUtility.cs @@ -0,0 +1,167 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.Diagnostics; + using System.Runtime.ExceptionServices; + using System.Threading; + using System.Threading.Tasks; + + internal static class BackoffRetryUtility + { + public const string ExceptionSourceToIgnoreForIgnoreForRetry = "BackoffRetryUtility"; + + public static Task ExecuteAsync( + Func> callbackMethod, + IRetryPolicy retryPolicy, + CancellationToken cancellationToken = default, + Action preRetryCallback = null) + { + return ExecuteRetryAsync( + () => callbackMethod(), + (Exception exception, CancellationToken token) => retryPolicy.ShouldRetryAsync(exception, cancellationToken), + null, + TimeSpan.Zero, + cancellationToken, + preRetryCallback); + } + + public static Task ExecuteAsync( + Func> callbackMethod, + IRetryPolicy retryPolicy, + CancellationToken cancellationToken = default, + Action preRetryCallback = null) + { + TPolicyArg1 policyArg1 = retryPolicy.InitialArgumentValue; + + return ExecuteRetryAsync( + () => callbackMethod(policyArg1), + async (exception, token) => + { + ShouldRetryResult result = await retryPolicy.ShouldRetryAsync(exception, cancellationToken); + policyArg1 = result.PolicyArg1; + return result; + }, + null, + TimeSpan.Zero, + cancellationToken, + preRetryCallback); + } + + public static Task ExecuteAsync( + Func> callbackMethod, + IRetryPolicy retryPolicy, + Func> inBackoffAlternateCallbackMethod, + TimeSpan minBackoffForInBackoffCallback, + CancellationToken cancellationToken = default, + Action preRetryCallback = null) + { + Func> inBackoffAlternateCallbackMethodAsync = null; + if (inBackoffAlternateCallbackMethod != null) + { + inBackoffAlternateCallbackMethodAsync = () => inBackoffAlternateCallbackMethod(); + } + + return ExecuteRetryAsync( + () => callbackMethod(), + (Exception exception, CancellationToken token) => retryPolicy.ShouldRetryAsync(exception, cancellationToken), + inBackoffAlternateCallbackMethodAsync, + minBackoffForInBackoffCallback, + cancellationToken, + preRetryCallback); + } + + public static Task ExecuteAsync( + Func> callbackMethod, + IRetryPolicy retryPolicy, + Func> inBackoffAlternateCallbackMethod, + TimeSpan minBackoffForInBackoffCallback, + CancellationToken cancellationToken = default, + Action preRetryCallback = null) + { + TPolicyArg1 policyArg1 = retryPolicy.InitialArgumentValue; + + Func> inBackoffAlternateCallbackMethodAsync = null; + if (inBackoffAlternateCallbackMethod != null) + { + inBackoffAlternateCallbackMethodAsync = () => inBackoffAlternateCallbackMethod(policyArg1); + } + + return ExecuteRetryAsync( + () => callbackMethod(policyArg1), + async (exception, token) => + { + ShouldRetryResult result = await retryPolicy.ShouldRetryAsync(exception, cancellationToken); + policyArg1 = result.PolicyArg1; + return result; + }, + inBackoffAlternateCallbackMethodAsync, + minBackoffForInBackoffCallback, + cancellationToken, + preRetryCallback); + } + + internal static async Task ExecuteRetryAsync( + Func> callbackMethod, + Func> callShouldRetry, + Func> inBackoffAlternateCallbackMethod, + TimeSpan minBackoffForInBackoffCallback, + CancellationToken cancellationToken, + Action preRetryCallback = null) + { + while (true) + { + ExceptionDispatchInfo exception; + try + { + cancellationToken.ThrowIfCancellationRequested(); + + return await callbackMethod(); + } + catch (Exception ex) + { + await Task.Yield(); + exception = ExceptionDispatchInfo.Capture(ex); + } + + // Don't retry if caller specified cancellation token was signaled. + // Note that we can't simply key off of OperationCancelledException + // here as this can be thrown independent of caller's CancellationToken + // being signaled. For example, WinFab throws OperationCancelledException + // when it gets E_ABORT from native code. + cancellationToken.ThrowIfCancellationRequested(); + + ShouldRetryResult result = await callShouldRetry(exception.SourceException, cancellationToken); + + result.ThrowIfDoneTrying(exception); + + TimeSpan backoffTime = result.BackoffTime; + if (inBackoffAlternateCallbackMethod != null && result.BackoffTime >= minBackoffForInBackoffCallback) + { + Stopwatch stopwatch = new Stopwatch(); + try + { + stopwatch.Start(); + return await inBackoffAlternateCallbackMethod(); + } + catch (Exception) + { + stopwatch.Stop(); + } + + backoffTime = result.BackoffTime > stopwatch.Elapsed ? result.BackoffTime - stopwatch.Elapsed : TimeSpan.Zero; + } + + preRetryCallback?.Invoke(exception.SourceException); + + if (backoffTime != TimeSpan.Zero) + { + await Task.Delay(backoffTime, cancellationToken); + } + } + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/IRetryPolicy.cs b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/IRetryPolicy.cs new file mode 100644 index 0000000000..92e0d9f88c --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/IRetryPolicy.cs @@ -0,0 +1,122 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.Runtime.ExceptionServices; + using System.Threading; + using System.Threading.Tasks; + + [System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.MaintainabilityRules", "SA1402:File may only contain a single type", Justification = "Mirrored file.")] + [System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.DocumentationRules", "SA1649:File name should match first type name", Justification = "Mirrored file.")] + internal class ShouldRetryResult + { + private static readonly ShouldRetryResult EmptyNoRetry = new ShouldRetryResult { ShouldRetry = false }; + + protected ShouldRetryResult() + { + } + + public bool ShouldRetry { get; protected set; } + + /// + /// Gets or sets how long to wait before next retry. 0 indicates retry immediately. + /// + public TimeSpan BackoffTime { get; protected set; } + + /// + /// Gets or sets exception to throw. + /// + public Exception ExceptionToThrow { get; protected set; } + + public void ThrowIfDoneTrying(ExceptionDispatchInfo capturedException) + { + if (this.ShouldRetry) + { + return; + } + + if (this.ExceptionToThrow == null) + { + capturedException.Throw(); + } + + if (capturedException != null && object.ReferenceEquals( + this.ExceptionToThrow, capturedException.SourceException)) + { + capturedException.Throw(); + } + else + { + throw this.ExceptionToThrow; + } + } + + public static ShouldRetryResult NoRetry(Exception exception = null) + { + if (exception == null) + { + return ShouldRetryResult.EmptyNoRetry; + } + + return new ShouldRetryResult { ShouldRetry = false, ExceptionToThrow = exception }; + } + + public static ShouldRetryResult RetryAfter(TimeSpan backoffTime) + { + return new ShouldRetryResult { ShouldRetry = true, BackoffTime = backoffTime }; + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.MaintainabilityRules", "SA1402:File may only contain a single type", Justification = "Mirrored file.")] + internal class ShouldRetryResult : ShouldRetryResult + { + /// + /// Gets argument to be passed to the callback method. + /// + public TPolicyArg1 PolicyArg1 { get; private set; } + + public static new ShouldRetryResult NoRetry(Exception exception = null) + { + return new ShouldRetryResult { ShouldRetry = false, ExceptionToThrow = exception }; + } + + public static ShouldRetryResult RetryAfter(TimeSpan backoffTime, TPolicyArg1 policyArg1) + { + return new ShouldRetryResult { ShouldRetry = true, BackoffTime = backoffTime, PolicyArg1 = policyArg1 }; + } + } + + internal abstract class IRetryPolicy + { + /// + /// Method that is called to determine from the policy that needs to retry on the exception + /// + /// Exception during the callback method invocation + /// Cancellation Token + /// If the retry needs to be attempted or not + public abstract Task ShouldRetryAsync(Exception exception, CancellationToken cancellationToken); + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.MaintainabilityRules", "SA1402:File may only contain a single type", Justification = "Mirrored file.")] + internal abstract class IRetryPolicy + { + /// + /// Method that is called to determine from the policy that needs to retry on the exception + /// + /// Exception during the callback method invocation + /// Cancelltion Token + /// If the retry needs to be attempted or not + public abstract Task> ShouldRetryAsync(Exception exception, CancellationToken cancellationToken); + + /// + /// Gets initial value of the template argument + /// + public abstract TPolicyArg1 InitialArgumentValue + { + get; + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/TaskFactoryExtensions.cs b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/TaskFactoryExtensions.cs new file mode 100644 index 0000000000..c55db1d218 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/TaskFactoryExtensions.cs @@ -0,0 +1,42 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.Threading; + using System.Threading.Tasks; + + /// + /// Extensions to task factory methods that are meant to be used as patterns of invocation of asynchronous operations + /// inside compute gateway ensuring continuity of execution on the current task scheduler. + /// Task scheduler is used to track resource consumption per tenant so it is critical that all async activity + /// pertaining to the tenant runs on the same task scheduler. + /// + internal static class TaskFactoryExtensions + { + /// Creates and starts a on the current task scheduler. + /// The started . + /// Instance of the to use for starting the task. + /// A function delegate that returns the future result to be available through the . + /// The type of the result available through the . + /// The exception that is thrown when the argument is null. + public static Task StartNewOnCurrentTaskSchedulerAsync(this TaskFactory taskFactory, Func function) + { + return taskFactory.StartNew(function, default, TaskCreationOptions.None, TaskScheduler.Current); + } + + /// Creates and starts a on the current task scheduler. + /// The started . + /// Instance of the to use for starting the task. + /// A function delegate that returns the future result to be available through the . + /// The that will be assigned to the new task. + /// The type of the result available through the . + /// The exception that is thrown when the argument is null. + public static Task StartNewOnCurrentTaskSchedulerAsync(this TaskFactory taskFactory, Func function, CancellationToken cancellationToken) + { + return taskFactory.StartNew(function, cancellationToken, TaskCreationOptions.None, TaskScheduler.Current); + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/TaskHelper.cs b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/TaskHelper.cs new file mode 100644 index 0000000000..e67337b9e7 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Encryption/src/Mirrored/TaskHelper.cs @@ -0,0 +1,102 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Encryption +{ + using System; + using System.Threading; + using System.Threading.Tasks; + + /// + /// The helper function relates to the async Task. + /// + internal static class TaskHelper + { + public static Task InlineIfPossibleAsync(Func function, IRetryPolicy retryPolicy, CancellationToken cancellationToken = default) + { + if (SynchronizationContext.Current == null) + { + if (retryPolicy == null) + { + // shortcut + return function(); + } + else + { + return BackoffRetryUtility.ExecuteAsync( + async () => + { + await function(); + return 0; + }, retryPolicy, + cancellationToken); + } + } + else + { + if (retryPolicy == null) + { + // shortcut + return Task.Run(function); + } + else + { + return Task.Run(() => BackoffRetryUtility.ExecuteAsync( + async () => + { + await function(); + return 0; + }, retryPolicy, + cancellationToken)); + } + } + } + +#pragma warning disable VSTHRD200 // Use "Async" suffix for async methods + public static Task InlineIfPossible(Func> function, IRetryPolicy retryPolicy, CancellationToken cancellationToken = default) +#pragma warning restore VSTHRD200 // Use "Async" suffix for async methods + { + if (SynchronizationContext.Current == null) + { + if (retryPolicy == null) + { + // shortcut + return function(); + } + else + { + return BackoffRetryUtility.ExecuteAsync( + () => function(), + retryPolicy, + cancellationToken); + } + } + else + { + if (retryPolicy == null) + { + // shortcut + return Task.Run(function); + } + else + { + return Task.Run(() => BackoffRetryUtility.ExecuteAsync( + () => function(), + retryPolicy, + cancellationToken)); + } + } + } + + public static Task RunInlineIfNeededAsync(Func> task) + { + if (SynchronizationContext.Current == null) + { + return task(); + } + + // Used on NETFX applications with SynchronizationContext when doing locking calls + return Task.Run(task); + } + } +} \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos.Encryption/tests/EmulatorTests/EncryptionTests.cs b/Microsoft.Azure.Cosmos.Encryption/tests/EmulatorTests/EncryptionTests.cs index 1b1ba3c274..0fa46c540c 100644 --- a/Microsoft.Azure.Cosmos.Encryption/tests/EmulatorTests/EncryptionTests.cs +++ b/Microsoft.Azure.Cosmos.Encryption/tests/EmulatorTests/EncryptionTests.cs @@ -9,6 +9,7 @@ namespace Microsoft.Azure.Cosmos.Encryption.EmulatorTests using System.IO; using System.Linq; using System.Net; + using System.Text; using System.Threading; using System.Threading.Tasks; using global::Azure.Core; @@ -701,18 +702,22 @@ public async Task EncryptionHandleDecryptionFailure() // validate changeFeed handling FeedIterator changeIterator = EncryptionTests.encryptionContainer.GetChangeFeedIterator( - continuationToken: null, - new ChangeFeedRequestOptions() - { - StartTime = DateTime.MinValue.ToUniversalTime() - }); + ChangeFeedStartFrom.Beginning()); while (changeIterator.HasMoreResults) { - readDocsLazily = await changeIterator.ReadNextAsync(); - if (readDocsLazily.Resource != null) + try + { + readDocsLazily = await changeIterator.ReadNextAsync(); + if (readDocsLazily.Resource != null) + { + await this.ValidateLazyDecryptionResponse(readDocsLazily, dek2); + } + } + catch (CosmosException ex) { - await this.ValidateLazyDecryptionResponse(readDocsLazily, dek2); + Assert.IsTrue(ex.Message.Contains("Response status code does not indicate success: NotModified (304)")); + break; } } @@ -1003,7 +1008,7 @@ public async Task EncryptionBulkCrud() } [TestMethod] - public async Task EncryptionTransactionBatchCrud() + public async Task EncryptionTransactionalBatchCrud() { string partitionKey = "thePK"; string dek1 = EncryptionTests.dekId; @@ -1048,6 +1053,30 @@ public async Task EncryptionTransactionBatchCrud() Assert.AreEqual(HttpStatusCode.OK, batchResponse.StatusCode); + TransactionalBatchOperationResult doc1 = batchResponse.GetOperationResultAtIndex(0); + Assert.AreEqual(doc1ToCreate, doc1.Resource); + + TransactionalBatchOperationResult doc2 = batchResponse.GetOperationResultAtIndex(1); + Assert.AreEqual(doc2ToCreate, doc2.Resource); + + TransactionalBatchOperationResult doc3 = batchResponse.GetOperationResultAtIndex(2); + Assert.AreEqual(doc1ToReplace, doc3.Resource); + + TransactionalBatchOperationResult doc4 = batchResponse.GetOperationResultAtIndex(3); + Assert.AreEqual(doc3ToCreate, doc4.Resource); + + TransactionalBatchOperationResult doc5 = batchResponse.GetOperationResultAtIndex(4); + Assert.AreEqual(doc4ToCreate, doc5.Resource); + + TransactionalBatchOperationResult doc6 = batchResponse.GetOperationResultAtIndex(5); + Assert.AreEqual(doc2ToReplace, doc6.Resource); + + TransactionalBatchOperationResult doc7 = batchResponse.GetOperationResultAtIndex(6); + Assert.AreEqual(doc1ToUpsert, doc7.Resource); + + TransactionalBatchOperationResult doc8 = batchResponse.GetOperationResultAtIndex(8); + Assert.AreEqual(doc2ToUpsert, doc8.Resource); + await EncryptionTests.VerifyItemByReadAsync(EncryptionTests.encryptionContainer, doc1ToCreate); await EncryptionTests.VerifyItemByReadAsync(EncryptionTests.encryptionContainer, doc2ToCreate, dekId: dek2); await EncryptionTests.VerifyItemByReadAsync(EncryptionTests.encryptionContainer, doc3ToCreate, isDocDecrypted: false); @@ -1084,6 +1113,56 @@ public async Task EncryptionTransactionBatchCrud() await EncryptionTests.VerifyItemByReadAsync(EncryptionTests.itemContainer, doc2ToUpsert); } + [TestMethod] + public async Task EncryptionTransactionalBatchWithCustomSerializer() + { + CustomSerializer customSerializer = new CustomSerializer(); + CosmosClient clientWithCustomSerializer = TestCommon.CreateCosmosClient(builder => builder + .WithCustomSerializer(customSerializer) + .Build()); + + Database databaseWithCustomSerializer = clientWithCustomSerializer.GetDatabase(EncryptionTests.database.Id); + Container containerWithCustomSerializer = databaseWithCustomSerializer.GetContainer(EncryptionTests.itemContainer.Id); + Container encryptionContainerWithCustomSerializer = containerWithCustomSerializer.WithEncryptor(EncryptionTests.encryptor); + + string partitionKey = "thePK"; + string dek1 = EncryptionTests.dekId; + + TestDoc doc1ToCreate = TestDoc.Create(partitionKey); + + ItemResponse doc1ToReplaceCreateResponse = await EncryptionTests.CreateItemAsync(encryptionContainerWithCustomSerializer, dek1, TestDoc.PathsToEncrypt, partitionKey); + TestDoc doc1ToReplace = doc1ToReplaceCreateResponse.Resource; + doc1ToReplace.NonSensitive = Guid.NewGuid().ToString(); + doc1ToReplace.Sensitive = Guid.NewGuid().ToString(); + + TransactionalBatchResponse batchResponse = await encryptionContainerWithCustomSerializer.CreateTransactionalBatch(new Cosmos.PartitionKey(partitionKey)) + .CreateItem(doc1ToCreate, EncryptionTests.GetBatchItemRequestOptions(dek1, TestDoc.PathsToEncrypt)) + .ReplaceItem(doc1ToReplace.Id, doc1ToReplace, EncryptionTests.GetBatchItemRequestOptions(dek1, TestDoc.PathsToEncrypt, doc1ToReplaceCreateResponse.ETag)) + .ExecuteAsync(); + + Assert.AreEqual(HttpStatusCode.OK, batchResponse.StatusCode); + // FromStream is called as part of CreateItem request + Assert.AreEqual(1, customSerializer.FromStreamCalled); + + TransactionalBatchOperationResult doc1 = batchResponse.GetOperationResultAtIndex(0); + Assert.AreEqual(doc1ToCreate, doc1.Resource); + Assert.AreEqual(2, customSerializer.FromStreamCalled); + + TransactionalBatchOperationResult doc2 = batchResponse.GetOperationResultAtIndex(1); + Assert.AreEqual(doc1ToReplace, doc2.Resource); + Assert.AreEqual(3, customSerializer.FromStreamCalled); + + await EncryptionTests.VerifyItemByReadAsync(encryptionContainerWithCustomSerializer, doc1ToCreate); + await EncryptionTests.VerifyItemByReadAsync(encryptionContainerWithCustomSerializer, doc1ToReplace); + + // Validate that the documents are encrypted as expected by trying to retrieve through regular (non-encryption) container + doc1ToCreate.Sensitive = null; + await EncryptionTests.VerifyItemByReadAsync(EncryptionTests.itemContainer, doc1ToCreate); + + doc1ToReplace.Sensitive = null; + await EncryptionTests.VerifyItemByReadAsync(EncryptionTests.itemContainer, doc1ToReplace); + } + private static async Task ValidateSprocResultsAsync(Container container, TestDoc expectedDoc) { string sprocId = Guid.NewGuid().ToString(); @@ -1253,23 +1332,27 @@ private async Task ValidateChangeFeedIteratorResponse( TestDoc testDoc2) { FeedIterator changeIterator = container.GetChangeFeedIterator( - continuationToken: null, - new ChangeFeedRequestOptions() - { - StartTime = DateTime.MinValue.ToUniversalTime() - }); + ChangeFeedStartFrom.Beginning()); List changeFeedReturnedDocs = new List(); while (changeIterator.HasMoreResults) { - FeedResponse testDocs = await changeIterator.ReadNextAsync(); - for (int index = 0; index < testDocs.Count; index++) + try { - if (testDocs.Resource.ElementAt(index).Id.Equals(testDoc1.Id) || testDocs.Resource.ElementAt(index).Id.Equals(testDoc2.Id)) + FeedResponse testDocs = await changeIterator.ReadNextAsync(); + for (int index = 0; index < testDocs.Count; index++) { - changeFeedReturnedDocs.Add(testDocs.Resource.ElementAt(index)); + if (testDocs.Resource.ElementAt(index).Id.Equals(testDoc1.Id) || testDocs.Resource.ElementAt(index).Id.Equals(testDoc2.Id)) + { + changeFeedReturnedDocs.Add(testDocs.Resource.ElementAt(index)); + } } } + catch (CosmosException ex) + { + Assert.IsTrue(ex.Message.Contains("Response status code does not indicate success: NotModified (304)")); + break; + } } Assert.AreEqual(changeFeedReturnedDocs.Count, 2); @@ -1810,5 +1893,40 @@ public override async ValueTask GetTokenCredentialAsync(Uri key return await Task.FromResult(new DefaultAzureCredential()); } } + + internal class CustomSerializer : CosmosSerializer + { + private readonly JsonSerializer serializer = new JsonSerializer(); + public int FromStreamCalled = 0; + + public override T FromStream(Stream stream) + { + this.FromStreamCalled++; + using (StreamReader sr = new StreamReader(stream)) + using (JsonReader reader = new JsonTextReader(sr)) + { + JsonSerializer serializer = new JsonSerializer(); + return this.serializer.Deserialize(reader); + } + } + + public override Stream ToStream(T input) + { + MemoryStream streamPayload = new MemoryStream(); + using (StreamWriter streamWriter = new StreamWriter(streamPayload, encoding: UTF8Encoding.UTF8, bufferSize: 1024, leaveOpen: true)) + { + using (JsonWriter writer = new JsonTextWriter(streamWriter)) + { + writer.Formatting = Newtonsoft.Json.Formatting.None; + this.serializer.Serialize(writer, input); + writer.Flush(); + streamWriter.Flush(); + } + } + + streamPayload.Position = 0; + return streamPayload; + } + } } }