Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make concurrent reads on HttpHeaders thread-safe #68115

Merged
merged 1 commit into from
May 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using System.Threading;


namespace System.Net.Http.Headers namespace System.Net.Http.Headers
{ {
Expand Down Expand Up @@ -571,6 +572,8 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders)
} }


private static HeaderStoreItemInfo CloneHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo) private static HeaderStoreItemInfo CloneHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo)
{
lock (sourceInfo)
{ {
var destinationInfo = new HeaderStoreItemInfo var destinationInfo = new HeaderStoreItemInfo
{ {
Expand Down Expand Up @@ -605,6 +608,7 @@ private static HeaderStoreItemInfo CloneHeaderInfo(HeaderDescriptor descriptor,


return destinationInfo; return destinationInfo;
} }
}


private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object source) private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object source)
{ {
Expand Down Expand Up @@ -713,6 +717,8 @@ private static void ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStor
{ {
// Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any) // Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any)
// before returning to the caller. // before returning to the caller.
lock (info)
{
Debug.Assert(!info.IsEmpty); Debug.Assert(!info.IsEmpty);
if (info.RawValue != null) if (info.RawValue != null)
{ {
Expand All @@ -735,9 +741,11 @@ private static void ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStor
info.RawValue = null; info.RawValue = null;
} }
} }
}


private static void ParseSingleRawHeaderValue(HeaderStoreItemInfo info, HeaderDescriptor descriptor, string rawValue) private static void ParseSingleRawHeaderValue(HeaderStoreItemInfo info, HeaderDescriptor descriptor, string rawValue)
{ {
Debug.Assert(Monitor.IsEntered(info));
if (descriptor.Parser == null) if (descriptor.Parser == null)
{ {
if (HttpRuleParser.ContainsNewLine(rawValue)) if (HttpRuleParser.ContainsNewLine(rawValue))
Expand Down Expand Up @@ -1095,6 +1103,8 @@ internal static void GetStoreValuesAsStringOrStringArray(HeaderDescriptor descri
return; return;
} }


lock (info)
{
int length = GetValueCount(info); int length = GetValueCount(info);


Span<string?> values; Span<string?> values;
Expand All @@ -1116,6 +1126,7 @@ internal static void GetStoreValuesAsStringOrStringArray(HeaderDescriptor descri


Debug.Assert(currentIndex == length); Debug.Assert(currentIndex == length);
} }
}


internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, object sourceValues, [NotNull] ref string[]? values) internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, object sourceValues, [NotNull] ref string[]? values)
{ {
Expand All @@ -1135,10 +1146,11 @@ internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, o
return 1; return 1;
} }


lock (info)
{
int length = GetValueCount(info); int length = GetValueCount(info);
Debug.Assert(length > 0);
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved


if (length > 0)
{
if (values.Length < length) if (values.Length < length)
{ {
values = new string[length]; values = new string[length];
Expand All @@ -1148,14 +1160,15 @@ internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, o
ReadStoreValues<object?>(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex); ReadStoreValues<object?>(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex);
ReadStoreValues<string?>(values, info.RawValue, null, ref currentIndex); ReadStoreValues<string?>(values, info.RawValue, null, ref currentIndex);
Debug.Assert(currentIndex == length); Debug.Assert(currentIndex == length);
}


return length; return length;
} }
}


private static int GetValueCount(HeaderStoreItemInfo info) private static int GetValueCount(HeaderStoreItemInfo info)
{ {
Debug.Assert(info != null); Debug.Assert(info != null);
Debug.Assert(Monitor.IsEntered(info));


return Count<object>(info.ParsedAndInvalidValues) + Count<string>(info.RawValue); return Count<object>(info.ParsedAndInvalidValues) + Count<string>(info.RawValue);


Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
using System.Linq; using System.Linq;
using System.Net.Http.Headers; using System.Net.Http.Headers;
using System.Tests; using System.Tests;

using System.Threading;
using System.Threading.Tasks;
using Xunit; using Xunit;


namespace System.Net.Http.Tests namespace System.Net.Http.Tests
Expand Down Expand Up @@ -2374,6 +2375,99 @@ public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeader
Assert.Equal(new[] { "newValue" }, valuesFor3); Assert.Equal(new[] { "newValue" }, valuesFor3);
} }


[Fact]
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved
public async Task ConcurrentReads_AreThreadSafe()
{
if (Environment.ProcessorCount < 3) return;

const int TestRunTimeMs = 100;

bool running = true;
HttpRequestHeaders headers = CreateHeaders();

Task readerTask1 = Task.Run(ReaderWorker);
Task readerTask2 = Task.Run(ReaderWorker);
Task writerTask = Task.Run(() =>
{
while (Volatile.Read(ref running)) Volatile.Write(ref headers, CreateHeaders());
});

await Task.Delay(TimeSpan.FromMilliseconds(TestRunTimeMs));

Volatile.Write(ref running, false);

await writerTask;
await readerTask1;
await readerTask2;

void ReaderWorker()
{
var tests = new Action<HttpRequestHeaders>[0];
tests = new Action<HttpRequestHeaders>[]
{
static headers => Assert.Equal(6, headers.NonValidated.Count),
static headers => Assert.True(headers.Contains("a")),
static headers => Assert.False(headers.Contains("b")),
static headers => Assert.True(headers.NonValidated.Contains("a")),
static headers => Assert.False(headers.NonValidated.Contains("b")),
static headers => Assert.True(headers.Contains("c")),
static headers => Assert.True(headers.Contains("f")),
static headers => Assert.True(headers.Contains("h")),
static headers => Assert.True(headers.Contains("H")),
static headers => Assert.True(headers.Contains("Accept")),
static headers => Assert.True(headers.Contains("Cache-Control")),
static headers => Assert.True(headers.TryGetValues("a", out var values) && values.Single() == "b"),
static headers => Assert.False(headers.TryGetValues("b", out _)),
static headers => Assert.True(headers.NonValidated.TryGetValues("a", out var values) && values.Single() == "b"),
static headers => Assert.False(headers.NonValidated.TryGetValues("b", out _)),
static headers => Assert.True(headers.TryGetValues("f", out var values) && values.Single() == "g"),
static headers => Assert.True(headers.NonValidated.TryGetValues("f", out var values) && values.Single() == "g"),
static headers => Assert.True(headers.TryGetValues("c", out var values) && values.Count() == 2 && values.First() == "d" && values.Last() == "e"),
static headers => Assert.True(headers.NonValidated.TryGetValues("c", out var values) && values.Count() == 2 && values.First() == "d" && values.Last() == "e"),
static headers => Assert.True(headers.TryGetValues("h", out var values) && values.Count() == 2 && values.First() == "i" && values.Last() == "j"),
static headers => Assert.True(headers.NonValidated.TryGetValues("h", out var values) && values.Count() == 2 && values.First() == "i" && values.Last() == "j"),
static headers => Assert.Equal("only-if-cached, private", headers.NonValidated["Cache-Control"].ToString()),
static headers => Assert.Equal("text/json", headers.Accept.Single().MediaType),
headers =>
{
var newHeaders = new HttpRequestMessage().Headers;
newHeaders.AddHeaders(headers);
for (int i = 0; i < tests.Length - 1; i++)
{
tests[i](newHeaders);
}
},
};

while (Volatile.Read(ref running))
{
tests[Random.Shared.Next(tests.Length)](Volatile.Read(ref headers));
}
}

static HttpRequestHeaders CreateHeaders()
{
HttpRequestHeaders headers = new HttpRequestMessage().Headers;

var actions = new Action<HttpRequestHeaders>[]
{
static headers => headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/json")),
static headers => headers.CacheControl = new CacheControlHeaderValue { Private = true, OnlyIfCached = true },
static headers => headers.Add("a", "b"),
static headers => headers.Add("c", new[] { "d", "e" }),
static headers => headers.TryAddWithoutValidation("f", "g"),
static headers => headers.TryAddWithoutValidation("h", new[] { "i", "j" }),
};

foreach (Action<HttpRequestHeaders> action in actions.OrderBy(_ => Random.Shared.Next()))
{
action(headers);
}

return headers;
}
}

[Fact] [Fact]
public void TryAddInvalidHeader_ShouldThrowFormatException() public void TryAddInvalidHeader_ShouldThrowFormatException()
{ {
Expand Down