Skip to content

Commit

Permalink
Make AsyncState thread safe (#4833)
Browse files Browse the repository at this point in the history
* Replace List with ConcurrentDictionary in AsyncState

* Remove unused namespace in FeaturesPooledPolicy.cs

* Fixes in tests

* Replace ConcurrentDictionary with custom Features class

* Remove unused namespace in AsyncState.cs

---------

Co-authored-by: Martin Obratil <[email protected]>
  • Loading branch information
mobratil and Martin Obratil committed Dec 22, 2023
1 parent 9f5be48 commit 78a1df9
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 67 deletions.
31 changes: 7 additions & 24 deletions src/Libraries/Microsoft.Extensions.AsyncState/AsyncState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading;
using Microsoft.Extensions.ObjectPool;
using Microsoft.Shared.Diagnostics;
Expand All @@ -13,7 +12,7 @@ namespace Microsoft.Extensions.AsyncState;
internal sealed class AsyncState : IAsyncState
{
private static readonly AsyncLocal<AsyncStateHolder> _asyncContextCurrent = new();
private static readonly ObjectPool<List<object?>> _featuresPool = PoolFactory.CreatePool(new FeaturesPooledPolicy());
private static readonly ObjectPool<Features> _featuresPool = PoolFactory.CreatePool(new FeaturesPooledPolicy());
private int _contextCount;

public void Initialize()
Expand All @@ -22,12 +21,12 @@ public void Initialize()

// Use an object indirection to hold the AsyncContext in the AsyncLocal,
// so it can be cleared in all ExecutionContexts when its cleared.
var features = new AsyncStateHolder
var asyncStateHolder = new AsyncStateHolder
{
Features = _featuresPool.Get()
};

_asyncContextCurrent.Value = features;
_asyncContextCurrent.Value = asyncStateHolder;
}

public void Reset()
Expand Down Expand Up @@ -60,9 +59,7 @@ public bool TryGet(AsyncStateToken token, out object? value)
return false;
}

EnsureCount(_asyncContextCurrent.Value.Features, token.Index + 1);

value = _asyncContextCurrent.Value.Features[token.Index];
value = _asyncContextCurrent.Value.Features.Get(token.Index);
return true;
}

Expand All @@ -86,28 +83,14 @@ public void Set(AsyncStateToken token, object? value)
Throw.InvalidOperationException("Context is not initialized");
}

EnsureCount(_asyncContextCurrent.Value.Features, token.Index + 1);

_asyncContextCurrent.Value.Features[token.Index] = value;
}

internal static void EnsureCount(List<object?> features, int count)
{
#if NET6_0_OR_GREATER
features.EnsureCapacity(count);
#endif
var difference = count - features.Count;

for (int i = 0; i < difference; i++)
{
features.Add(null);
}
_asyncContextCurrent.Value.Features.Set(token.Index, value);
}

internal int ContextCount => Volatile.Read(ref _contextCount);

private sealed class AsyncStateHolder
{
public List<object?>? Features { get; set; }
public Features? Features { get; set; }
}

}
45 changes: 45 additions & 0 deletions src/Libraries/Microsoft.Extensions.AsyncState/Features.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;

namespace Microsoft.Extensions.AsyncState;

internal sealed class Features
{
private readonly List<object?> _items = [];

public object? Get(int index)
{
return _items.Count <= index ? null : _items[index];
}

public void Set(int index, object? value)
{
if (_items.Count <= index)
{
lock (_items)
{
var count = index + 1;

#if NET6_0_OR_GREATER
_items.EnsureCapacity(count);
#endif

var difference = count - _items.Count;

for (int i = 0; i < difference; i++)
{
_items.Add(null);
}
}
}

_items[index] = value;
}

public void Clear()
{
_items.Clear();
}
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Microsoft.Extensions.ObjectPool;

namespace Microsoft.Extensions.AsyncState;

internal sealed class FeaturesPooledPolicy : IPooledObjectPolicy<List<object?>>
internal sealed class FeaturesPooledPolicy : IPooledObjectPolicy<Features>
{
/// <inheritdoc/>
public List<object?> Create()
public Features Create()
{
return [];
return new Features();
}

/// <inheritdoc/>
public bool Return(List<object?> obj)
public bool Return(Features obj)
{
for (int i = 0; i < obj.Count; i++)
{
obj[i] = null;
}

obj.Clear();
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -205,28 +204,4 @@ public void RegisterContextCorrectly()

Assert.Equal(3, asyncState.ContextCount);
}

[Fact]
public void EnsureCount_IncreasesCountCorrectly()
{
var l = new List<object?>();
AsyncState.EnsureCount(l, 5);
Assert.Equal(5, l.Count);
}

[Fact]
public void EnsureCount_WhenCountLessThanExpected()
{
var l = new List<object?>(new object?[5]);
AsyncState.EnsureCount(l, 2);
Assert.Equal(5, l.Count);
}

[Fact]
public void EnsureCount_WhenCountEqualWithExpected()
{
var l = new List<object?>(new object?[5]);
AsyncState.EnsureCount(l, 5);
Assert.Equal(5, l.Count);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using Xunit;

namespace Microsoft.Extensions.AsyncState.Test;
Expand All @@ -13,20 +12,22 @@ public void Return_ShouldBeTrue()
{
var policy = new FeaturesPooledPolicy();

Assert.True(policy.Return([]));
Assert.True(policy.Return(new Features()));
}

[Fact]
public void Return_ShouldNullList()
{
var policy = new FeaturesPooledPolicy();

var list = policy.Create();
list.Add(string.Empty);
list.Add(Array.Empty<int>());
list.Add(new object());
var features = policy.Create();
features.Set(0, string.Empty);
features.Set(1, Array.Empty<int>());
features.Set(2, new object());

Assert.True(policy.Return(list));
Assert.All(list, el => Assert.Null(el));
Assert.True(policy.Return(features));
Assert.Null(features.Get(0));
Assert.Null(features.Get(1));
Assert.Null(features.Get(2));
}
}

0 comments on commit 78a1df9

Please sign in to comment.