diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 1510bb42496d9..75cc5bb3a4aa5 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -8,6 +8,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Threading; namespace System.Net.Http.Headers { @@ -572,38 +573,41 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) private static HeaderStoreItemInfo CloneHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo) { - var destinationInfo = new HeaderStoreItemInfo + lock (sourceInfo) { - // Always copy raw values - RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue) - }; + var destinationInfo = new HeaderStoreItemInfo + { + // Always copy raw values + RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue) + }; - if (descriptor.Parser == null) - { - sourceInfo.AssertContainsNoInvalidValues(); - destinationInfo.ParsedAndInvalidValues = CloneStringHeaderInfoValues(sourceInfo.ParsedAndInvalidValues); - } - else - { - // We have a parser, so we also have to clone invalid values and parsed values. - if (sourceInfo.ParsedAndInvalidValues != null) + if (descriptor.Parser == null) { - List? sourceValues = sourceInfo.ParsedAndInvalidValues as List; - if (sourceValues == null) - { - CloneAndAddValue(destinationInfo, sourceInfo.ParsedAndInvalidValues); - } - else + sourceInfo.AssertContainsNoInvalidValues(); + destinationInfo.ParsedAndInvalidValues = CloneStringHeaderInfoValues(sourceInfo.ParsedAndInvalidValues); + } + else + { + // We have a parser, so we also have to clone invalid values and parsed values. + if (sourceInfo.ParsedAndInvalidValues != null) { - foreach (object item in sourceValues) + List? sourceValues = sourceInfo.ParsedAndInvalidValues as List; + if (sourceValues == null) { - CloneAndAddValue(destinationInfo, item); + CloneAndAddValue(destinationInfo, sourceInfo.ParsedAndInvalidValues); + } + else + { + foreach (object item in sourceValues) + { + CloneAndAddValue(destinationInfo, item); + } } } } - } - return destinationInfo; + return destinationInfo; + } } private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object source) @@ -713,31 +717,35 @@ private static void ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStor { // Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any) // before returning to the caller. - Debug.Assert(!info.IsEmpty); - if (info.RawValue != null) + lock (info) { - if (info.RawValue is List rawValues) + Debug.Assert(!info.IsEmpty); + if (info.RawValue != null) { - foreach (string rawValue in rawValues) + if (info.RawValue is List rawValues) { + foreach (string rawValue in rawValues) + { + ParseSingleRawHeaderValue(info, descriptor, rawValue); + } + } + else + { + string? rawValue = info.RawValue as string; + Debug.Assert(rawValue is not null); ParseSingleRawHeaderValue(info, descriptor, rawValue); } - } - else - { - string? rawValue = info.RawValue as string; - Debug.Assert(rawValue is not null); - ParseSingleRawHeaderValue(info, descriptor, rawValue); - } - // At this point all values are either in info.ParsedValue, info.InvalidValue. Reset RawValue. - Debug.Assert(info.ParsedAndInvalidValues is not null); - info.RawValue = null; + // At this point all values are either in info.ParsedValue, info.InvalidValue. Reset RawValue. + Debug.Assert(info.ParsedAndInvalidValues is not null); + info.RawValue = null; + } } } private static void ParseSingleRawHeaderValue(HeaderStoreItemInfo info, HeaderDescriptor descriptor, string rawValue) { + Debug.Assert(Monitor.IsEntered(info)); if (descriptor.Parser == null) { if (HttpRuleParser.ContainsNewLine(rawValue)) @@ -1095,26 +1103,29 @@ internal static void GetStoreValuesAsStringOrStringArray(HeaderDescriptor descri return; } - int length = GetValueCount(info); - - Span values; - singleValue = null; - if (length == 1) + lock (info) { - multiValue = null; - values = MemoryMarshal.CreateSpan(ref singleValue, 1); - } - else - { - Debug.Assert(length > 1, "The header should have been removed when it became empty"); - values = multiValue = new string[length]; - } + int length = GetValueCount(info); - int currentIndex = 0; - ReadStoreValues(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex); - ReadStoreValues(values, info.RawValue, null, ref currentIndex); + Span values; + singleValue = null; + if (length == 1) + { + multiValue = null; + values = MemoryMarshal.CreateSpan(ref singleValue, 1); + } + else + { + Debug.Assert(length > 1, "The header should have been removed when it became empty"); + values = multiValue = new string[length]; + } + + int currentIndex = 0; + ReadStoreValues(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex); + ReadStoreValues(values, info.RawValue, null, ref currentIndex); - Debug.Assert(currentIndex == length); + Debug.Assert(currentIndex == length); + } } internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, object sourceValues, [NotNull] ref string[]? values) @@ -1135,10 +1146,11 @@ internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, o return 1; } - int length = GetValueCount(info); - - if (length > 0) + lock (info) { + int length = GetValueCount(info); + Debug.Assert(length > 0); + if (values.Length < length) { values = new string[length]; @@ -1148,14 +1160,15 @@ internal static int GetStoreValuesIntoStringArray(HeaderDescriptor descriptor, o ReadStoreValues(values, info.ParsedAndInvalidValues, descriptor.Parser, ref currentIndex); ReadStoreValues(values, info.RawValue, null, ref currentIndex); Debug.Assert(currentIndex == length); - } - return length; + return length; + } } private static int GetValueCount(HeaderStoreItemInfo info) { Debug.Assert(info != null); + Debug.Assert(Monitor.IsEntered(info)); return Count(info.ParsedAndInvalidValues) + Count(info.RawValue); diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs index 8cc0341eab8f3..188a775761661 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs @@ -7,7 +7,8 @@ using System.Linq; using System.Net.Http.Headers; using System.Tests; - +using System.Threading; +using System.Threading.Tasks; using Xunit; namespace System.Net.Http.Tests @@ -2374,6 +2375,99 @@ public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeader Assert.Equal(new[] { "newValue" }, valuesFor3); } + [Fact] + public async Task ConcurrentReads_AreThreadSafe() + { + if (Environment.ProcessorCount < 3) return; + + const int TestRunTimeMs = 100; + + bool running = true; + HttpRequestHeaders headers = CreateHeaders(); + + Task readerTask1 = Task.Run(ReaderWorker); + Task readerTask2 = Task.Run(ReaderWorker); + Task writerTask = Task.Run(() => + { + while (Volatile.Read(ref running)) Volatile.Write(ref headers, CreateHeaders()); + }); + + await Task.Delay(TimeSpan.FromMilliseconds(TestRunTimeMs)); + + Volatile.Write(ref running, false); + + await writerTask; + await readerTask1; + await readerTask2; + + void ReaderWorker() + { + var tests = new Action[0]; + tests = new Action[] + { + static headers => Assert.Equal(6, headers.NonValidated.Count), + static headers => Assert.True(headers.Contains("a")), + static headers => Assert.False(headers.Contains("b")), + static headers => Assert.True(headers.NonValidated.Contains("a")), + static headers => Assert.False(headers.NonValidated.Contains("b")), + static headers => Assert.True(headers.Contains("c")), + static headers => Assert.True(headers.Contains("f")), + static headers => Assert.True(headers.Contains("h")), + static headers => Assert.True(headers.Contains("H")), + static headers => Assert.True(headers.Contains("Accept")), + static headers => Assert.True(headers.Contains("Cache-Control")), + static headers => Assert.True(headers.TryGetValues("a", out var values) && values.Single() == "b"), + static headers => Assert.False(headers.TryGetValues("b", out _)), + static headers => Assert.True(headers.NonValidated.TryGetValues("a", out var values) && values.Single() == "b"), + static headers => Assert.False(headers.NonValidated.TryGetValues("b", out _)), + static headers => Assert.True(headers.TryGetValues("f", out var values) && values.Single() == "g"), + static headers => Assert.True(headers.NonValidated.TryGetValues("f", out var values) && values.Single() == "g"), + static headers => Assert.True(headers.TryGetValues("c", out var values) && values.Count() == 2 && values.First() == "d" && values.Last() == "e"), + static headers => Assert.True(headers.NonValidated.TryGetValues("c", out var values) && values.Count() == 2 && values.First() == "d" && values.Last() == "e"), + static headers => Assert.True(headers.TryGetValues("h", out var values) && values.Count() == 2 && values.First() == "i" && values.Last() == "j"), + static headers => Assert.True(headers.NonValidated.TryGetValues("h", out var values) && values.Count() == 2 && values.First() == "i" && values.Last() == "j"), + static headers => Assert.Equal("only-if-cached, private", headers.NonValidated["Cache-Control"].ToString()), + static headers => Assert.Equal("text/json", headers.Accept.Single().MediaType), + headers => + { + var newHeaders = new HttpRequestMessage().Headers; + newHeaders.AddHeaders(headers); + for (int i = 0; i < tests.Length - 1; i++) + { + tests[i](newHeaders); + } + }, + }; + + while (Volatile.Read(ref running)) + { + tests[Random.Shared.Next(tests.Length)](Volatile.Read(ref headers)); + } + } + + static HttpRequestHeaders CreateHeaders() + { + HttpRequestHeaders headers = new HttpRequestMessage().Headers; + + var actions = new Action[] + { + static headers => headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/json")), + static headers => headers.CacheControl = new CacheControlHeaderValue { Private = true, OnlyIfCached = true }, + static headers => headers.Add("a", "b"), + static headers => headers.Add("c", new[] { "d", "e" }), + static headers => headers.TryAddWithoutValidation("f", "g"), + static headers => headers.TryAddWithoutValidation("h", new[] { "i", "j" }), + }; + + foreach (Action action in actions.OrderBy(_ => Random.Shared.Next())) + { + action(headers); + } + + return headers; + } + } + [Fact] public void TryAddInvalidHeader_ShouldThrowFormatException() {