diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/NullabilityInfoContext.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/NullabilityInfoContext.cs index b58e03aeb6e44..1a87dc84ee96b 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/NullabilityInfoContext.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/NullabilityInfoContext.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Diagnostics; namespace System.Reflection { @@ -27,7 +28,7 @@ private enum NotAnnotatedStatus Internal = 0x2 // internal members not annotated } - private NullabilityState GetNullableContext(MemberInfo? memberInfo) + private NullabilityState? GetNullableContext(MemberInfo? memberInfo) { while (memberInfo != null) { @@ -51,7 +52,7 @@ private NullabilityState GetNullableContext(MemberInfo? memberInfo) memberInfo = memberInfo.DeclaringType; } - return NullabilityState.Unknown; + return null; } /// @@ -71,13 +72,11 @@ public NullabilityInfo Create(ParameterInfo parameterInfo) EnsureIsSupported(); - if (parameterInfo.Member is MethodInfo method && IsPrivateOrInternalMethodAndAnnotationDisabled(method)) - { - return new NullabilityInfo(parameterInfo.ParameterType, NullabilityState.Unknown, NullabilityState.Unknown, null, Array.Empty()); - } - IList attributes = parameterInfo.GetCustomAttributesData(); - NullabilityInfo nullability = GetNullabilityInfo(parameterInfo.Member, parameterInfo.ParameterType, attributes); + NullableAttributeStateParser parser = parameterInfo.Member is MethodBase method && IsPrivateOrInternalMethodAndAnnotationDisabled(method) + ? NullableAttributeStateParser.Unknown + : CreateParser(attributes); + NullabilityInfo nullability = GetNullabilityInfo(parameterInfo.Member, parameterInfo.ParameterType, parser); if (nullability.ReadState != NullabilityState.Unknown) { @@ -114,7 +113,7 @@ private void CheckParameterMetadataType(ParameterInfo parameter, NullabilityInfo if (metaParameter != null) { - CheckGenericParameters(nullability, metaMethod, metaParameter.ParameterType); + CheckGenericParameters(nullability, metaMethod, metaParameter.ParameterType, parameter.Member.ReflectedType); } } } @@ -131,40 +130,45 @@ private static MethodInfo GetMethodMetadataDefinition(MethodInfo method) private void CheckNullabilityAttributes(NullabilityInfo nullability, IList attributes) { + var codeAnalysisReadState = NullabilityState.Unknown; + var codeAnalysisWriteState = NullabilityState.Unknown; + foreach (CustomAttributeData attribute in attributes) { if (attribute.AttributeType.Namespace == "System.Diagnostics.CodeAnalysis") { - if (attribute.AttributeType.Name == "NotNullAttribute" && - nullability.ReadState == NullabilityState.Nullable) + if (attribute.AttributeType.Name == "NotNullAttribute") { - nullability.ReadState = NullabilityState.NotNull; - break; + codeAnalysisReadState = NullabilityState.NotNull; } else if ((attribute.AttributeType.Name == "MaybeNullAttribute" || attribute.AttributeType.Name == "MaybeNullWhenAttribute") && - nullability.ReadState == NullabilityState.NotNull && + codeAnalysisReadState == NullabilityState.Unknown && !nullability.Type.IsValueType) { - nullability.ReadState = NullabilityState.Nullable; - break; + codeAnalysisReadState = NullabilityState.Nullable; } - - if (attribute.AttributeType.Name == "DisallowNullAttribute" && - nullability.WriteState == NullabilityState.Nullable) + else if (attribute.AttributeType.Name == "DisallowNullAttribute") { - nullability.WriteState = NullabilityState.NotNull; - break; + codeAnalysisWriteState = NullabilityState.NotNull; } else if (attribute.AttributeType.Name == "AllowNullAttribute" && - nullability.WriteState == NullabilityState.NotNull && + codeAnalysisWriteState == NullabilityState.Unknown && !nullability.Type.IsValueType) { - nullability.WriteState = NullabilityState.Nullable; - break; + codeAnalysisWriteState = NullabilityState.Nullable; } } } + + if (codeAnalysisReadState != NullabilityState.Unknown) + { + nullability.ReadState = codeAnalysisReadState; + } + if (codeAnalysisWriteState != NullabilityState.Unknown) + { + nullability.WriteState = codeAnalysisWriteState; + } } /// @@ -184,17 +188,15 @@ public NullabilityInfo Create(PropertyInfo propertyInfo) EnsureIsSupported(); - NullabilityInfo nullability = GetNullabilityInfo(propertyInfo, propertyInfo.PropertyType, propertyInfo.GetCustomAttributesData()); MethodInfo? getter = propertyInfo.GetGetMethod(true); MethodInfo? setter = propertyInfo.GetSetMethod(true); + bool annotationsDisabled = (getter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(getter)) + && (setter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(setter)); + NullableAttributeStateParser parser = annotationsDisabled ? NullableAttributeStateParser.Unknown : CreateParser(propertyInfo.GetCustomAttributesData()); + NullabilityInfo nullability = GetNullabilityInfo(propertyInfo, propertyInfo.PropertyType, parser); if (getter != null) { - if (IsPrivateOrInternalMethodAndAnnotationDisabled(getter)) - { - nullability.ReadState = NullabilityState.Unknown; - } - CheckNullabilityAttributes(nullability, getter.ReturnParameter.GetCustomAttributesData()); } else @@ -204,12 +206,7 @@ public NullabilityInfo Create(PropertyInfo propertyInfo) if (setter != null) { - if (IsPrivateOrInternalMethodAndAnnotationDisabled(setter)) - { - nullability.WriteState = NullabilityState.Unknown; - } - - CheckNullabilityAttributes(nullability, setter.GetParameters()[0].GetCustomAttributesData()); + CheckNullabilityAttributes(nullability, setter.GetParameters()[^1].GetCustomAttributesData()); } else { @@ -219,7 +216,7 @@ public NullabilityInfo Create(PropertyInfo propertyInfo) return nullability; } - private bool IsPrivateOrInternalMethodAndAnnotationDisabled(MethodInfo method) + private bool IsPrivateOrInternalMethodAndAnnotationDisabled(MethodBase method) { if ((method.IsPrivate || method.IsFamilyAndAssembly || method.IsAssembly) && IsPublicOnly(method.IsPrivate, method.IsFamilyAndAssembly, method.IsAssembly, method.Module)) @@ -247,7 +244,7 @@ public NullabilityInfo Create(EventInfo eventInfo) EnsureIsSupported(); - return GetNullabilityInfo(eventInfo, eventInfo.EventHandlerType!, eventInfo.GetCustomAttributesData()); + return GetNullabilityInfo(eventInfo, eventInfo.EventHandlerType!, CreateParser(eventInfo.GetCustomAttributesData())); } /// @@ -267,13 +264,9 @@ public NullabilityInfo Create(FieldInfo fieldInfo) EnsureIsSupported(); - if (IsPrivateOrInternalFieldAndAnnotationDisabled(fieldInfo)) - { - return new NullabilityInfo(fieldInfo.FieldType, NullabilityState.Unknown, NullabilityState.Unknown, null, Array.Empty()); - } - IList attributes = fieldInfo.GetCustomAttributesData(); - NullabilityInfo nullability = GetNullabilityInfo(fieldInfo, fieldInfo.FieldType, attributes); + NullableAttributeStateParser parser = IsPrivateOrInternalFieldAndAnnotationDisabled(fieldInfo) ? NullableAttributeStateParser.Unknown : CreateParser(attributes); + NullabilityInfo nullability = GetNullabilityInfo(fieldInfo, fieldInfo.FieldType, parser); CheckNullabilityAttributes(nullability, attributes); return nullability; } @@ -341,13 +334,13 @@ private NotAnnotatedStatus PopulateAnnotationInfo(IList cus return NotAnnotatedStatus.None; } - private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, IList customAttributes) + private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser) { int index = 0; - return GetNullabilityInfo(memberInfo, type, customAttributes, ref index); + return GetNullabilityInfo(memberInfo, type, parser, ref index); } - private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, IList customAttributes, ref int index) + private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser, ref int index) { NullabilityState state = NullabilityState.Unknown; NullabilityInfo? elementState = null; @@ -370,19 +363,20 @@ private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, ILi if (underlyingType.IsGenericType) { - index++; + ++index; } } else { - if (!ParseNullableState(customAttributes, index++, ref state)) + if (!parser.ParseNullableState(index++, ref state) + && GetNullableContext(memberInfo) is { } contextState) { - state = GetNullableContext(memberInfo); + state = contextState; } if (type.IsArray) { - elementState = GetNullabilityInfo(memberInfo, type.GetElementType()!, customAttributes, ref index); + elementState = GetNullabilityInfo(memberInfo, type.GetElementType()!, parser, ref index); } } @@ -393,7 +387,7 @@ private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, ILi for (int i = 0; i < genericArguments.Length; i++) { - genericArgumentsState[i] = GetNullabilityInfo(memberInfo, genericArguments[i], customAttributes, ref index); + genericArgumentsState[i] = GetNullabilityInfo(memberInfo, genericArguments[i], parser, ref index); } } @@ -407,7 +401,7 @@ private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, ILi return nullability; } - private static bool ParseNullableState(IList customAttributes, int index, ref NullabilityState state) + private static NullableAttributeStateParser CreateParser(IList customAttributes) { foreach (CustomAttributeData attribute in customAttributes) { @@ -415,26 +409,11 @@ private static bool ParseNullableState(IList customAttribut attribute.AttributeType.Namespace == CompilerServicesNameSpace && attribute.ConstructorArguments.Count == 1) { - object? o = attribute.ConstructorArguments[0].Value; - - if (o is byte b) - { - state = TranslateByte(b); - return true; - } - else if (o is ReadOnlyCollection args && - index < args.Count && - args[index].Value is byte elementB) - { - state = TranslateByte(elementB); - return true; - } - - break; + return new(attribute.ConstructorArguments[0].Value); } } - return false; + return new(null); } private void TryLoadGenericMetaTypeNullability(MemberInfo memberInfo, NullabilityInfo nullability) @@ -452,7 +431,7 @@ private void TryLoadGenericMetaTypeNullability(MemberInfo memberInfo, Nullabilit if (metaType != null) { - CheckGenericParameters(nullability, metaMember!, metaType); + CheckGenericParameters(nullability, metaMember!, metaType, memberInfo.ReflectedType); } } @@ -477,19 +456,14 @@ private static Type GetPropertyMetaType(PropertyInfo property) return property.GetSetMethod(true)!.GetParameters()[0].ParameterType; } - private void CheckGenericParameters(NullabilityInfo nullability, MemberInfo metaMember, Type metaType) + private void CheckGenericParameters(NullabilityInfo nullability, MemberInfo metaMember, Type metaType, Type? reflectedType) { if (metaType.IsGenericParameter) { - NullabilityState state = nullability.ReadState; - - if (state == NullabilityState.NotNull && !ParseNullableState(metaType.GetCustomAttributesData(), 0, ref state)) + if (nullability.ReadState == NullabilityState.NotNull) { - state = GetNullableContext(metaType); + TryUpdateGenericParameterNullability(nullability, metaType, reflectedType); } - - nullability.ReadState = state; - nullability.WriteState = state; } else if (metaType.ContainsGenericParameters) { @@ -499,37 +473,136 @@ private void CheckGenericParameters(NullabilityInfo nullability, MemberInfo meta for (int i = 0; i < genericArguments.Length; i++) { - if (genericArguments[i].IsGenericParameter) - { - int index = i + 1; - NullabilityInfo n = GetNullabilityInfo(metaMember, genericArguments[i], genericArguments[i].GetCustomAttributesData(), ref index); - nullability.GenericTypeArguments[i].ReadState = n.ReadState; - nullability.GenericTypeArguments[i].WriteState = n.WriteState; - } - else - { - UpdateGenericArrayElements(nullability.GenericTypeArguments[i].ElementType, metaMember, genericArguments[i]); - } + CheckGenericParameters(nullability.GenericTypeArguments[i], metaMember, genericArguments[i], reflectedType); } } - else + else if (nullability.ElementType is { } elementNullability && metaType.IsArray) { - UpdateGenericArrayElements(nullability.ElementType, metaMember, metaType); + CheckGenericParameters(elementNullability, metaMember, metaType.GetElementType()!, reflectedType); } } } - private void UpdateGenericArrayElements(NullabilityInfo? elementState, MemberInfo metaMember, Type metaType) + private bool TryUpdateGenericParameterNullability(NullabilityInfo nullability, Type genericParameter, Type? reflectedType) { - if (metaType.IsArray && elementState != null - && metaType.GetElementType()!.IsGenericParameter) + Debug.Assert(genericParameter.IsGenericParameter); + + if (reflectedType is not null + && !genericParameter.IsGenericMethodParameter + && TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, reflectedType, reflectedType)) + { + return true; + } + + var state = NullabilityState.Unknown; + if (CreateParser(genericParameter.GetCustomAttributesData()).ParseNullableState(0, ref state)) + { + nullability.ReadState = state; + nullability.WriteState = state; + return true; + } + + if (GetNullableContext(genericParameter) is { } contextState) { - Type elementType = metaType.GetElementType()!; - int index = 0; - NullabilityInfo n = GetNullabilityInfo(metaMember, elementType, elementType.GetCustomAttributesData(), ref index); - elementState.ReadState = n.ReadState; - elementState.WriteState = n.WriteState; + nullability.ReadState = contextState; + nullability.WriteState = contextState; + return true; } + + return false; + } + + private bool TryUpdateGenericTypeParameterNullabilityFromReflectedType(NullabilityInfo nullability, Type genericParameter, Type context, Type reflectedType) + { + Debug.Assert(genericParameter.IsGenericParameter && !genericParameter.IsGenericMethodParameter); + + Type contextTypeDefinition = context.IsGenericType && !context.IsGenericTypeDefinition ? context.GetGenericTypeDefinition() : context; + if (genericParameter.DeclaringType == contextTypeDefinition) + { + return false; + } + + Type? baseType = contextTypeDefinition.BaseType; + if (baseType is null) + { + return false; + } + + if (!baseType.IsGenericType + || (baseType.IsGenericTypeDefinition ? baseType : baseType.GetGenericTypeDefinition()) != genericParameter.DeclaringType) + { + return TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, baseType, reflectedType); + } + + Type[] genericArguments = baseType.GetGenericArguments(); + Type genericArgument = genericArguments[genericParameter.GenericParameterPosition]; + if (genericArgument.IsGenericParameter) + { + return TryUpdateGenericParameterNullability(nullability, genericArgument, reflectedType); + } + + NullableAttributeStateParser parser = CreateParser(contextTypeDefinition.GetCustomAttributesData()); + int nullabilityStateIndex = 1; // start at 1 since index 0 is the type itself + for (int i = 0; i < genericParameter.GenericParameterPosition; i++) + { + nullabilityStateIndex += CountNullabilityStates(genericArguments[i]); + } + return TryPopulateNullabilityInfo(nullability, parser, ref nullabilityStateIndex); + + static int CountNullabilityStates(Type type) + { + Type underlyingType = Nullable.GetUnderlyingType(type) ?? type; + if (underlyingType.IsGenericType) + { + int count = 1; + foreach (Type genericArgument in underlyingType.GetGenericArguments()) + { + count += CountNullabilityStates(genericArgument); + } + return count; + } + if (underlyingType.IsArray) + { + return 1 + CountNullabilityStates(underlyingType.GetElementType()!); + } + + return type.IsValueType ? 0 : 1; + } + } + + private bool TryPopulateNullabilityInfo(NullabilityInfo nullability, NullableAttributeStateParser parser, ref int index) + { + bool isValueType = nullability.Type.IsValueType; + if (!isValueType) + { + var state = NullabilityState.Unknown; + if (!parser.ParseNullableState(index, ref state)) + { + return false; + } + + nullability.ReadState = state; + nullability.WriteState = state; + } + + if (!isValueType || (Nullable.GetUnderlyingType(nullability.Type) ?? nullability.Type).IsGenericType) + { + index++; + } + + if (nullability.GenericTypeArguments.Length > 0) + { + foreach (NullabilityInfo genericTypeArgumentNullability in nullability.GenericTypeArguments) + { + TryPopulateNullabilityInfo(genericTypeArgumentNullability, parser, ref index); + } + } + else if (nullability.ElementType is { } elementTypeNullability) + { + TryPopulateNullabilityInfo(elementTypeNullability, parser, ref index); + } + + return true; } private static NullabilityState TranslateByte(object? value) @@ -544,5 +617,35 @@ private static NullabilityState TranslateByte(byte b) => 2 => NullabilityState.Nullable, _ => NullabilityState.Unknown }; + + private readonly struct NullableAttributeStateParser + { + private static readonly object UnknownByte = (byte)0; + + private readonly object? _nullableAttributeArgument; + + public NullableAttributeStateParser(object? nullableAttributeArgument) + { + this._nullableAttributeArgument = nullableAttributeArgument; + } + + public static NullableAttributeStateParser Unknown => new(UnknownByte); + + public bool ParseNullableState(int index, ref NullabilityState state) + { + switch (this._nullableAttributeArgument) + { + case byte b: + state = TranslateByte(b); + return true; + case ReadOnlyCollection args + when index < args.Count && args[index].Value is byte elementB: + state = TranslateByte(elementB); + return true; + default: + return false; + } + } + } } } diff --git a/src/libraries/System.Runtime/tests/System/Reflection/NullabilityInfoContextTests.cs b/src/libraries/System.Runtime/tests/System/Reflection/NullabilityInfoContextTests.cs index ce939d4526302..7dd057d8cca0a 100644 --- a/src/libraries/System.Runtime/tests/System/Reflection/NullabilityInfoContextTests.cs +++ b/src/libraries/System.Runtime/tests/System/Reflection/NullabilityInfoContextTests.cs @@ -5,6 +5,7 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO.Enumeration; +using System.Runtime.Serialization; using System.Text.RegularExpressions; using Microsoft.DotNet.RemoteExecutor; using Xunit; @@ -30,10 +31,10 @@ public static IEnumerable FieldTestData() yield return new object[] { "FieldDisallowNull", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; yield return new object[] { "FieldAllowNull", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; yield return new object[] { "FieldDisallowNull2", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; - yield return new object[] { "FieldAllowNull2", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; + yield return new object[] { "FieldAllowNull2", NullabilityState.NotNull, NullabilityState.NotNull, typeof(string) }; yield return new object[] { "FieldNotNull", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; yield return new object[] { "FieldMaybeNull", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; - yield return new object[] { "FieldMaybeNull2", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; + yield return new object[] { "FieldMaybeNull2", NullabilityState.NotNull, NullabilityState.NotNull, typeof(string) }; yield return new object[] { "FieldNotNull2", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; } @@ -86,11 +87,12 @@ public static IEnumerable PropertyTestData() yield return new object[] { "PropertyDisallowNull", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; yield return new object[] { "PropertyAllowNull", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; yield return new object[] { "PropertyDisallowNull2", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; - yield return new object[] { "PropertyAllowNull2", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; + yield return new object[] { "PropertyAllowNull2", NullabilityState.NotNull, NullabilityState.NotNull, typeof(string) }; yield return new object[] { "PropertyNotNull", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; yield return new object[] { "PropertyMaybeNull", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; - yield return new object[] { "PropertyMaybeNull2", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; + yield return new object[] { "PropertyMaybeNull2", NullabilityState.NotNull, NullabilityState.NotNull, typeof(string) }; yield return new object[] { "PropertyNotNull2", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; + yield return new object[] { "Item", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; } [Theory] @@ -433,7 +435,7 @@ public void GenericFieldNullableValueTypeTest(string fieldName, NullabilityState Assert.Equal(type, nullability.Type); } - public static IEnumerable GenericNotnullConstraintFieldsTestData() + public static IEnumerable GenericNotNullConstraintFieldsTestData() { yield return new object[] { "FieldNullable", NullabilityState.Nullable, NullabilityState.Nullable, typeof(string) }; yield return new object[] { "FieldUnknown", NullabilityState.Unknown, NullabilityState.Unknown, typeof(string) }; @@ -441,7 +443,7 @@ public static IEnumerable GenericNotnullConstraintFieldsTestData() } [Theory] - [MemberData(nameof(GenericNotnullConstraintFieldsTestData))] + [MemberData(nameof(GenericNotNullConstraintFieldsTestData))] public void GenericNotNullConstraintFieldsTest(string fieldName, NullabilityState readState, NullabilityState writeState, Type type) { FieldInfo field = typeof(GenericTestConstrainedNotNull).GetField(fieldName, flags)!; @@ -570,7 +572,7 @@ public void GenericListAndDictionaryFieldTest() public static IEnumerable MethodReturnParameterTestData() { - yield return new object[] { "MethodReturnsUnknown", NullabilityState.Unknown, NullabilityState.Unknown}; + yield return new object[] { "MethodReturnsUnknown", NullabilityState.Unknown, NullabilityState.Unknown }; yield return new object[] { "MethodReturnsNullNon", NullabilityState.Nullable, NullabilityState.NotNull }; yield return new object[] { "MethodReturnsNullNull", NullabilityState.Nullable, NullabilityState.Nullable }; yield return new object[] { "MethodReturnsNonNull", NullabilityState.NotNull, NullabilityState.Nullable }; @@ -671,7 +673,7 @@ public void MethodParametersTest(string methodName, NullabilityState stringState public static IEnumerable MethodGenericParametersTestData() { - yield return new object[] { "MethodParametersUnknown", NullabilityState.Unknown, NullabilityState.Unknown, NullabilityState.Unknown, NullabilityState.Unknown}; + yield return new object[] { "MethodParametersUnknown", NullabilityState.Unknown, NullabilityState.Unknown, NullabilityState.Unknown, NullabilityState.Unknown }; yield return new object[] { "MethodArgsNullGenericNullDictValueGeneric", NullabilityState.Nullable, NullabilityState.NotNull, NullabilityState.Nullable, NullabilityState.Nullable }; yield return new object[] { "MethodArgsGenericDictValueNullGeneric", NullabilityState.Nullable, NullabilityState.NotNull, NullabilityState.Nullable, NullabilityState.NotNull }; } @@ -710,7 +712,7 @@ public void NullablePublicOnlyStringTypeTest(string methodName, NullabilityState Assert.Equal(param1State, param1.ReadState); Assert.Equal(param2State, param2.ReadState); Assert.Equal(param3State, param3.ReadState); - if (param2.ElementType != null) + if (param2.ElementType != null) { Assert.Equal(NullabilityState.Nullable, param2.ElementType.ReadState); } @@ -739,7 +741,7 @@ public void NullablePublicOnlyOtherTypesTest() PropertyInfo publicGetPrivateSetNullableProperty = typeof(FileSystemEntry).GetProperty("Directory", flags)!; info = nullabilityContext.Create(publicGetPrivateSetNullableProperty); Assert.Equal(NullabilityState.NotNull, info.ReadState); - Assert.Equal(NullabilityState.Unknown, info.WriteState); + Assert.Equal(NullabilityState.NotNull, info.WriteState); MethodInfo protectedNullableReturnMethod = type.GetMethod("GetPropertyImpl", flags)!; info = nullabilityContext.Create(protectedNullableReturnMethod.ReturnParameter); @@ -748,8 +750,8 @@ public void NullablePublicOnlyOtherTypesTest() MethodInfo privateValueTypeReturnMethod = type.GetMethod("BinarySearch", flags)!; info = nullabilityContext.Create(privateValueTypeReturnMethod.ReturnParameter); - Assert.Equal(NullabilityState.Unknown, info.ReadState); - Assert.Equal(NullabilityState.Unknown, info.WriteState); + Assert.Equal(NullabilityState.NotNull, info.ReadState); + Assert.Equal(NullabilityState.NotNull, info.WriteState); Type regexType = typeof(Regex); FieldInfo protectedInternalNullableField = regexType.GetField("pattern", flags)!; @@ -761,13 +763,18 @@ public void NullablePublicOnlyOtherTypesTest() info = nullabilityContext.Create(privateNullableField); Assert.Equal(NullabilityState.Unknown, info.ReadState); Assert.Equal(NullabilityState.Unknown, info.WriteState); + + ConstructorInfo privateConstructor = typeof(IndexOutOfRangeException) + .GetConstructor(BindingFlags.NonPublic | BindingFlags.Instance, new[] { typeof(SerializationInfo), typeof(StreamingContext) })!; + info = nullabilityContext.Create(privateConstructor.GetParameters()[0]); + Assert.Equal(NullabilityState.Unknown, info.WriteState); } public static IEnumerable DifferentContextTestData() { yield return new object[] { "PropertyDisabled", NullabilityState.Unknown, NullabilityState.Unknown, typeof(string) }; - yield return new object[] { "PropertyDisabledAllowNull", NullabilityState.Unknown, NullabilityState.Unknown, typeof(string) }; - yield return new object[] { "PropertyDisabledMaybeNull", NullabilityState.Unknown, NullabilityState.Unknown, typeof(string) }; + yield return new object[] { "PropertyDisabledAllowNull", NullabilityState.Unknown, NullabilityState.Nullable, typeof(string) }; + yield return new object[] { "PropertyDisabledMaybeNull", NullabilityState.Nullable, NullabilityState.Unknown, typeof(string) }; yield return new object[] { "PropertyEnabledAllowNull", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; yield return new object[] { "PropertyEnabledNotNull", NullabilityState.NotNull, NullabilityState.Nullable, typeof(string) }; yield return new object[] { "PropertyEnabledMaybeNull", NullabilityState.Nullable, NullabilityState.NotNull, typeof(string) }; @@ -984,6 +991,105 @@ public void TestValueTupleGenericTypeParameters(string fieldName, NullabilitySta Assert.Equal(param1, tupleInfo.GenericTypeArguments[0].ReadState); Assert.Equal(param2, tupleInfo.GenericTypeArguments[1].ReadState); } + + public static IEnumerable GenericInheritanceTestData() + { + yield return new object?[] { typeof(ListOfUnconstrained), NullabilityState.Nullable, null }; + yield return new object?[] { typeof(ListUnconstrainedOfNullable), NullabilityState.Nullable, null }; + yield return new object?[] { typeof(ListUnconstrainedOfNullableOfObject<>), NullabilityState.NotNull, null }; + yield return new object?[] { typeof(ListOfArrayOfNullableString), NullabilityState.NotNull, NullabilityState.Nullable }; + yield return new object?[] { typeof(ListOfNotNull>), NullabilityState.NotNull, NullabilityState.NotNull }; + yield return new object?[] { typeof(ListOfListOfObject), NullabilityState.NotNull, NullabilityState.NotNull }; + yield return new object?[] { typeof(ListMultiGenericOfNotNull), NullabilityState.NotNull, null }; + } + + [Theory] + [MemberData(nameof(GenericInheritanceTestData))] + [SkipOnMono("Nullability attributes trimmed on Mono")] + public void TestGenericInheritance(Type listType, NullabilityState parameterState, NullabilityState? subState) + { + var addParameterInfo = nullabilityContext.Create(listType.GetMethod("Add")!.GetParameters()[0]); + Validate(addParameterInfo); + + var copyToParameterInfo = nullabilityContext.Create( + listType.GetMethod("CopyTo", new[] { addParameterInfo.Type.MakeArrayType() })! + .GetParameters()[0]); + Assert.Equal(NullabilityState.NotNull, copyToParameterInfo.ReadState); + Assert.Equal(NullabilityState.NotNull, copyToParameterInfo.WriteState); + Assert.NotNull(copyToParameterInfo.ElementType); + Validate(copyToParameterInfo.ElementType!); + + void Validate(NullabilityInfo info) + { + Assert.Equal(parameterState, info.ReadState); + Assert.Equal(parameterState, info.WriteState); + if (subState != null) + { + NullabilityInfo subInfo = info.ElementType ?? info.GenericTypeArguments[0]; + Assert.Equal(subState, subInfo.ReadState); + Assert.Equal(subState, subInfo.WriteState); + Assert.True(info.GenericTypeArguments.Length <= 1); + } + else + { + Assert.Null(info.ElementType); + Assert.Empty(info.GenericTypeArguments); + } + } + } + + [Fact] + [SkipOnMono("Nullability attributes trimmed on Mono")] + public void TestDeeplyNestedGenericInheritance() + { + var copyToMethodInfo = nullabilityContext.Create( + typeof(ListOfTupleOfDictionaryOfStringNullableBoolIntNullableObject).GetMethod("CopyTo", new[] { typeof((Dictionary, int, object?)[]) })! + .GetParameters()[0]); + + Validate(copyToMethodInfo, typeof((Dictionary, int, object?)[]), NullabilityState.NotNull); + Assert.NotNull(copyToMethodInfo.ElementType); + + var tupleInfo = copyToMethodInfo.ElementType!; + Validate(tupleInfo, typeof((Dictionary, int, object?)), NullabilityState.NotNull); + + var dictionaryInfo = tupleInfo.GenericTypeArguments[0]; + Validate(dictionaryInfo, typeof(Dictionary), NullabilityState.NotNull); + + var stringInfo = dictionaryInfo.GenericTypeArguments[0]; + Validate(stringInfo, typeof(string), NullabilityState.NotNull); + + var nullableBoolInfo = dictionaryInfo.GenericTypeArguments[1]; + Validate(nullableBoolInfo, typeof(bool?), NullabilityState.Nullable); + + var intInfo = tupleInfo.GenericTypeArguments[1]; + Validate(intInfo, typeof(int), NullabilityState.NotNull); + + var objectInfo = tupleInfo.GenericTypeArguments[2]; + Validate(objectInfo, typeof(object), NullabilityState.Nullable); + + void Validate(NullabilityInfo info, Type type, NullabilityState state) + { + Assert.Equal(type, info.Type); + Assert.Equal(state, info.ReadState); + Assert.Equal(state, info.WriteState); + Assert.Equal(type.IsGenericType && type.GetGenericTypeDefinition() != typeof(Nullable<>) ? type.GetGenericArguments().Length : 0, info.GenericTypeArguments.Length); + Assert.Equal(type.IsArray, info.ElementType is not null); + } + } + + [Fact] + [SkipOnMono("Nullability attributes trimmed on Mono")] + public void TestNestedGenericInheritanceWithMultipleParameters() + { + var item3Info = nullabilityContext.Create(typeof(DerivesFromTupleOfNestedGenerics).GetProperty("Item3")!); + + Assert.Equal(typeof(IDisposable[]), item3Info.Type); + Assert.Equal(NullabilityState.Nullable, item3Info.ReadState); + Assert.Equal(NullabilityState.Unknown, item3Info.WriteState); // read-only property + + Assert.Equal(NullabilityState.NotNull, item3Info.ElementType!.ReadState); + Assert.Equal(NullabilityState.NotNull, item3Info.ElementType.WriteState); + } } #pragma warning disable CS0649, CS0067, CS0414 @@ -1023,7 +1129,7 @@ public class TypeWithNoContext [AllowNull] public string PropertyEnabledAllowNull { get; set; } [NotNull] public string? PropertyEnabledNotNull { get; set; } = null!; [DisallowNull] public string? PropertyEnabledDisallowNull { get; set; } = null!; - [MaybeNull] public string PropertyEnabledMaybeNull { get; set; } + [MaybeNull] public string PropertyEnabledMaybeNull { get; set; } public string? PropertyEnabledNullable { get; set; } public string PropertyEnabledNonNullable { get; set; } = null!; #nullable disable @@ -1059,14 +1165,15 @@ public void MethodParametersUnknown(string s, IDictionary dict) [AllowNull] public string PropertyAllowNull { get; set; } [NotNull] public string? PropertyNotNull { get; set; } [MaybeNull] public string PropertyMaybeNull { get; set; } - // only AllowNull matter + // only DisallowNull matters [AllowNull, DisallowNull] public string PropertyAllowNull2 { get; set; } - // only DisallowNull matter + // only AllowNull matters [AllowNull, DisallowNull] public string? PropertyDisallowNull2 { get; set; } - // only NotNull matter + // only NotNull matters [NotNull, MaybeNull] public string? PropertyNotNull2 { get; set; } - // only MaybeNull matter + // only NotNull matters [NotNull, MaybeNull] public string PropertyMaybeNull2 { get; set; } + [DisallowNull] public string? this[int i] { get => null; set { } } private protected string?[]?[]? PropertyJaggedArrayNullNullNull { get; set; } public static string?[]?[] PropertyJaggedArrayNullNullNon { get; set; } = null!; public string?[][]? PropertyJaggedArrayNullNonNull { get; set; } @@ -1174,7 +1281,7 @@ public void MethodParametersUnknown(T s, IDictionary dict) { } [DisallowNull] public T? FieldDisallowNull; [AllowNull] protected T FieldAllowNull; [NotNull] public T? FieldNotNull = default; - [MaybeNull] protected internal T FieldMaybeNull = default!; + [MaybeNull] protected internal T FieldMaybeNull = default!; public List FieldListOfT = default!; public Dictionary FieldDictionaryStringToT = default!; @@ -1213,4 +1320,33 @@ internal class GenericTestConstrainedStruct where T : struct public T PropertyNullableEnabled { get; set; } public T? PropertyNullable { get; set; } } + + public class ListOfUnconstrained : List { } + + public class ListUnconstrainedOfNullable : ListOfUnconstrained where T : class? { } + + public class ListUnconstrainedOfNullableOfObject : ListUnconstrainedOfNullable { } + + public class ListOfArrayOfNullableString : List { } + + public class ListOfNotNull : List where T : notnull { } + + public class ListOfListOfObject : List> { } + + public class ListMultiGenericOfNotNull : List + where T : class? + where U : class + where V : class? + { + } + + public class ListOfTupleOfDictionaryOfStringNullableBoolIntNullableObject : List<(Dictionary, int, object?)> { } + + public class DerivesFromTupleOfNestedGenerics : Tuple, Dictionary>, IDisposable[]?> + { + public DerivesFromTupleOfNestedGenerics(List item1, Dictionary> item2, IDisposable[]? item3) + : base(item1, item2, item3) + { + } + } }