Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make concurrent reads on HttpHeaders thread-safe #68115

Merged
merged 1 commit into from
May 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;

namespace System.Net.Http.Headers
{
Expand Down Expand Up @@ -572,38 +573,41 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders)

private static HeaderStoreItemInfo CloneHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo)
{
var destinationInfo = new HeaderStoreItemInfo
lock (sourceInfo)
{
// Always copy raw values
RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue)
};
var destinationInfo = new HeaderStoreItemInfo
{
// Always copy raw values
RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue)
};

if (descriptor.Parser == null)
{
sourceInfo.AssertContainsNoInvalidValues();
destinationInfo.ParsedAndInvalidValues = CloneStringHeaderInfoValues(sourceInfo.ParsedAndInvalidValues);
}
else
{
// We have a parser, so we also have to clone invalid values and parsed values.
if (sourceInfo.ParsedAndInvalidValues != null)
if (descriptor.Parser == null)
{
List<object>? sourceValues = sourceInfo.ParsedAndInvalidValues as List<object>;
if (sourceValues == null)
{
CloneAndAddValue(destinationInfo, sourceInfo.ParsedAndInvalidValues);
}
else
sourceInfo.AssertContainsNoInvalidValues();
destinationInfo.ParsedAndInvalidValues = CloneStringHeaderInfoValues(sourceInfo.ParsedAndInvalidValues);
}
else
{
// We have a parser, so we also have to clone invalid values and parsed values.
if (sourceInfo.ParsedAndInvalidValues != null)
{
foreach (object item in sourceValues)
List<object>? sourceValues = sourceInfo.ParsedAndInvalidValues as List<object>;
if (sourceValues == null)
{
CloneAndAddValue(destinationInfo, item);
CloneAndAddValue(destinationInfo, sourceInfo.ParsedAndInvalidValues);
}
else
{
foreach (object item in sourceValues)
{
CloneAndAddValue(destinationInfo, item);
}
}
}
}
}

return destinationInfo;
return destinationInfo;
}
}

private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object source)
Expand Down Expand Up @@ -713,31 +717,35 @@ private static void ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStor
{
// Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any)
// before returning to the caller.
Debug.Assert(!info.IsEmpty);
if (info.RawValue != null)
lock (info)
{
if (info.RawValue is List<string> rawValues)
Debug.Assert(!info.IsEmpty);
if (info.RawValue != null)
{
foreach (string rawValue in rawValues)
if (info.RawValue is List<string> rawValues)
{
foreach (string rawValue in rawValues)
{
ParseSingleRawHeaderValue(info, descriptor, rawValue);
}
}
else
{
string? rawValue = info.RawValue as string;
Debug.Assert(rawValue is not null);
ParseSingleRawHeaderValue(info, descriptor, rawValue);
}
}
else
{
string? rawValue = info.RawValue as string;
Debug.Assert(rawValue is not null);
ParseSingleRawHeaderValue(info, descriptor, rawValue);
}

// At this point all values are either in info.ParsedValue, info.InvalidValue. Reset RawValue.
Debug.Assert(info.ParsedAndInvalidValues is not null);
info.RawValue = null;
// At this point all values are either in info.ParsedValue, info.InvalidValue. Reset RawValue.
Debug.Assert(info.ParsedAndInvalidValues is not null);
info.RawValue = null;
}
}
}

private static void ParseSingleRawHeaderValue(HeaderStoreItemInfo info, HeaderDescriptor descriptor, string rawValue)
{
Debug.Assert(Monitor.IsEntered(info));
if (descriptor.Parser == null)
{
if (HttpRuleParser.ContainsNewLine(rawValue))
Expand Down Expand Up @@ -1095,26 +1103,29 @@ internal static void GetStoreValuesAsStringOrStringArray(HeaderDescriptor descri
return;
}

int length = GetValueCount(info);

Span<string?> values;
singleValue = null;
if (length == 1)
lock (info)
{
multiValue = null;
values = MemoryMarshal.CreateSpan(ref singleValue, 1);
}
else
{
Debug.Assert(length > 1, "The header should have been removed when it became empty");
values = multiValue = new string[length];
}
int length = GetValueCount(info);

int currentIndex = 0;
ReadStoreValues<object?>(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex);
ReadStoreValues<string?>(values, info.RawValue, null, ref currentIndex);
Span<string?> values;
singleValue = null;
if (length == 1)
{
multiValue = null;
values = MemoryMarshal.CreateSpan(ref singleValue, 1);
}
else
{
Debug.Assert(length > 1, "The header should have been removed when it became empty");
values = multiValue = new string[length];
}

int currentIndex = 0;
ReadStoreValues<object?>(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex);
ReadStoreValues<string?>(values, info.RawValue, null, ref currentIndex);

Debug.Assert(currentIndex == length);
Debug.Assert(currentIndex == length);
}
}

internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, object sourceValues, [NotNull] ref string[]? values)
Expand All @@ -1135,10 +1146,11 @@ internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, o
return 1;
}

int length = GetValueCount(info);

if (length > 0)
lock (info)
{
int length = GetValueCount(info);
Debug.Assert(length > 0);
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved

if (values.Length < length)
{
values = new string[length];
Expand All @@ -1148,14 +1160,15 @@ internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, o
ReadStoreValues<object?>(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex);
ReadStoreValues<string?>(values, info.RawValue, null, ref currentIndex);
Debug.Assert(currentIndex == length);
}

return length;
return length;
}
}

private static int GetValueCount(HeaderStoreItemInfo info)
{
Debug.Assert(info != null);
Debug.Assert(Monitor.IsEntered(info));

return Count<object>(info.ParsedAndInvalidValues) + Count<string>(info.RawValue);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
using System.Linq;
using System.Net.Http.Headers;
using System.Tests;

using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace System.Net.Http.Tests
Expand Down Expand Up @@ -2374,6 +2375,99 @@ public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeader
Assert.Equal(new[] { "newValue" }, valuesFor3);
}

[Fact]
MihaZupan marked this conversation as resolved.
Show resolved Hide resolved
public async Task ConcurrentReads_AreThreadSafe()
{
if (Environment.ProcessorCount < 3) return;

const int TestRunTimeMs = 100;

bool running = true;
HttpRequestHeaders headers = CreateHeaders();

Task readerTask1 = Task.Run(ReaderWorker);
Task readerTask2 = Task.Run(ReaderWorker);
Task writerTask = Task.Run(() =>
{
while (Volatile.Read(ref running)) Volatile.Write(ref headers, CreateHeaders());
});

await Task.Delay(TimeSpan.FromMilliseconds(TestRunTimeMs));

Volatile.Write(ref running, false);

await writerTask;
await readerTask1;
await readerTask2;

void ReaderWorker()
{
var tests = new Action<HttpRequestHeaders>[0];
tests = new Action<HttpRequestHeaders>[]
{
static headers => Assert.Equal(6, headers.NonValidated.Count),
static headers => Assert.True(headers.Contains("a")),
static headers => Assert.False(headers.Contains("b")),
static headers => Assert.True(headers.NonValidated.Contains("a")),
static headers => Assert.False(headers.NonValidated.Contains("b")),
static headers => Assert.True(headers.Contains("c")),
static headers => Assert.True(headers.Contains("f")),
static headers => Assert.True(headers.Contains("h")),
static headers => Assert.True(headers.Contains("H")),
static headers => Assert.True(headers.Contains("Accept")),
static headers => Assert.True(headers.Contains("Cache-Control")),
static headers => Assert.True(headers.TryGetValues("a", out var values) && values.Single() == "b"),
static headers => Assert.False(headers.TryGetValues("b", out _)),
static headers => Assert.True(headers.NonValidated.TryGetValues("a", out var values) && values.Single() == "b"),
static headers => Assert.False(headers.NonValidated.TryGetValues("b", out _)),
static headers => Assert.True(headers.TryGetValues("f", out var values) && values.Single() == "g"),
static headers => Assert.True(headers.NonValidated.TryGetValues("f", out var values) && values.Single() == "g"),
static headers => Assert.True(headers.TryGetValues("c", out var values) && values.Count() == 2 && values.First() == "d" && values.Last() == "e"),
static headers => Assert.True(headers.NonValidated.TryGetValues("c", out var values) && values.Count() == 2 && values.First() == "d" && values.Last() == "e"),
static headers => Assert.True(headers.TryGetValues("h", out var values) && values.Count() == 2 && values.First() == "i" && values.Last() == "j"),
static headers => Assert.True(headers.NonValidated.TryGetValues("h", out var values) && values.Count() == 2 && values.First() == "i" && values.Last() == "j"),
static headers => Assert.Equal("only-if-cached, private", headers.NonValidated["Cache-Control"].ToString()),
static headers => Assert.Equal("text/json", headers.Accept.Single().MediaType),
headers =>
{
var newHeaders = new HttpRequestMessage().Headers;
newHeaders.AddHeaders(headers);
for (int i = 0; i < tests.Length - 1; i++)
{
tests[i](newHeaders);
}
},
};

while (Volatile.Read(ref running))
{
tests[Random.Shared.Next(tests.Length)](Volatile.Read(ref headers));
}
}

static HttpRequestHeaders CreateHeaders()
{
HttpRequestHeaders headers = new HttpRequestMessage().Headers;

var actions = new Action<HttpRequestHeaders>[]
{
static headers => headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/json")),
static headers => headers.CacheControl = new CacheControlHeaderValue { Private = true, OnlyIfCached = true },
static headers => headers.Add("a", "b"),
static headers => headers.Add("c", new[] { "d", "e" }),
static headers => headers.TryAddWithoutValidation("f", "g"),
static headers => headers.TryAddWithoutValidation("h", new[] { "i", "j" }),
};

foreach (Action<HttpRequestHeaders> action in actions.OrderBy(_ => Random.Shared.Next()))
{
action(headers);
}

return headers;
}
}

[Fact]
public void TryAddInvalidHeader_ShouldThrowFormatException()
{
Expand Down