Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update nullability convention to new nullability metadata #16385

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public virtual ConventionSet CreateConventionSet()
var databaseGeneratedAttributeConvention = new DatabaseGeneratedAttributeConvention(Dependencies);
var requiredPropertyAttributeConvention = new RequiredPropertyAttributeConvention(Dependencies);
var nonNullableReferencePropertyConvention = new NonNullableReferencePropertyConvention(Dependencies);
var nonNullableNavigationConvention = new NonNullableNavigationConvention(Dependencies);
var maxLengthAttributeConvention = new MaxLengthAttributeConvention(Dependencies);
var stringLengthAttributeConvention = new StringLengthAttributeConvention(Dependencies);
var timestampAttributeConvention = new TimestampAttributeConvention(Dependencies);
Expand Down Expand Up @@ -171,11 +172,14 @@ public virtual ConventionSet CreateConventionSet()
conventionSet.ModelFinalizedConventions.Add(foreignKeyIndexConvention);
conventionSet.ModelFinalizedConventions.Add(foreignKeyPropertyDiscoveryConvention);
conventionSet.ModelFinalizedConventions.Add(servicePropertyDiscoveryConvention);
conventionSet.ModelFinalizedConventions.Add(nonNullableReferencePropertyConvention);
conventionSet.ModelFinalizedConventions.Add(nonNullableNavigationConvention);
conventionSet.ModelFinalizedConventions.Add(new ValidatingConvention(Dependencies));
// Don't add any more conventions to ModelFinalizedConventions after ValidatingConvention

conventionSet.NavigationAddedConventions.Add(backingFieldConvention);
conventionSet.NavigationAddedConventions.Add(new RequiredNavigationAttributeConvention(Dependencies));
conventionSet.NavigationAddedConventions.Add(new NonNullableNavigationConvention(Dependencies));
conventionSet.NavigationAddedConventions.Add(nonNullableNavigationConvention);
conventionSet.NavigationAddedConventions.Add(inversePropertyAttributeConvention);
conventionSet.NavigationAddedConventions.Add(foreignKeyPropertyDiscoveryConvention);
conventionSet.NavigationAddedConventions.Add(relationshipDiscoveryConvention);
Expand Down
119 changes: 101 additions & 18 deletions src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
Expand All @@ -13,11 +15,14 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
/// A base type for conventions that configure model aspects based on whether the member type
/// is a non-nullable reference type.
/// </summary>
public abstract class NonNullableConventionBase
public abstract class NonNullableConventionBase : IModelFinalizedConvention
roji marked this conversation as resolved.
Show resolved Hide resolved
{
// For the interpretation of nullability metadata, see
// https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-metadata.md

private const string StateAnnotationName = "NonNullableConventionState";
private const string NullableAttributeFullName = "System.Runtime.CompilerServices.NullableAttribute";
private Type _nullableAttrType;
private FieldInfo _nullableFlagsFieldInfo;
private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute";

/// <summary>
/// Creates a new instance of <see cref="NonNullableConventionBase" />.
Expand All @@ -33,40 +38,118 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend
/// </summary>
protected virtual ProviderConventionSetBuilderDependencies Dependencies { get; }

private byte? GetNullabilityContextFlag(NonNullabilityConventionState state, Attribute[] attributes)
{
if (attributes.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute attribute)
{
var attributeType = attribute.GetType();

if (attributeType != state.NullableContextAttrType)
{
state.NullableContextFlagFieldInfo = attributeType.GetField("Flag");
state.NullableContextAttrType = attributeType;
}

if (state.NullableContextFlagFieldInfo?.GetValue(attribute) is byte flag)
{
return flag;
}
}

return null;
}

/// <summary>
/// Returns a value indicating whether the member type is a non-nullable reference type.
/// </summary>
/// <param name="modelBuilder"> The model builder used to build the model. </param>
/// <param name="memberInfo"> The member info. </param>
/// <returns> <c>true</c> if the member type is a non-nullable reference type. </returns>
protected virtual bool IsNonNullable([NotNull] MemberInfo memberInfo)
protected virtual bool IsNonNullable(
[NotNull] IConventionModelBuilder modelBuilder,
[NotNull] MemberInfo memberInfo)
{
var state = GetOrInitializeState(modelBuilder);

// For C# 8.0 nullable types, the C# currently synthesizes a NullableAttribute that expresses nullability into assemblies
// it produces. If the model is spread across more than one assembly, there will be multiple versions of this attribute,
// so look for it by name, caching to avoid reflection on every check.
// Note that this may change - if https://github.com/dotnet/corefx/issues/36222 is done we can remove all of this.
if (!(Attribute.GetCustomAttributes(memberInfo, true)
.FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName)
is { } attribute))

// First look for NullableAttribute on the member itself
if (Attribute.GetCustomAttributes(memberInfo, true)
.FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute)
{
var attributeType = attribute.GetType();

if (attributeType != state.NullableAttrType)
{
state.NullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
state.NullableAttrType = attributeType;
}

if (state.NullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags
&& flags.FirstOrDefault() == 1)
{
return true;
}
}

// No attribute on the member, try to find a NullableContextAttribute on the declaring type
var type = memberInfo.DeclaringType;
if (type != null)
{
return false;
if (state.TypeNonNullabilityContextCache.TryGetValue(type, out var cachedTypeNonNullable))
{
return cachedTypeNonNullable;
}

var typeContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(type));
if (typeContextFlag.HasValue)
{
return state.TypeNonNullabilityContextCache[type] = typeContextFlag.Value == 1;
}
}

var attributeType = attribute.GetType();
if (attributeType != _nullableAttrType)
// Not found at the type level, try at the module level
var module = memberInfo.Module;
if (!state.ModuleNonNullabilityContextCache.TryGetValue(module, out var moduleNonNullable))
{
_nullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
_nullableAttrType = attributeType;
var moduleContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(memberInfo.Module));
moduleNonNullable = state.ModuleNonNullabilityContextCache[module] =
moduleContextFlag.HasValue && moduleContextFlag == 1;
}

// For the interpretation of NullableFlags, see
// https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-reference-types.md#annotations
if (_nullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags
&& flags.FirstOrDefault() == 1)
if (type != null)
{
return true;
state.TypeNonNullabilityContextCache[type] = moduleNonNullable;
}

return false;
return moduleNonNullable;
}

private NonNullabilityConventionState GetOrInitializeState(IConventionModelBuilder modelBuilder)
=> (NonNullabilityConventionState)(
modelBuilder.Metadata.FindAnnotation(StateAnnotationName) ??
modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NonNullabilityConventionState())
).Value;

/// <summary>
/// Called after a model is finalized. Removes the cached state annotation used by this convention.
/// </summary>
/// <param name="modelBuilder"> The builder for the model. </param>
/// <param name="context"> Additional information associated with convention execution. </param>
public virtual void ProcessModelFinalized(IConventionModelBuilder modelBuilder, IConventionContext<IConventionModelBuilder> context)
=> modelBuilder.Metadata.RemoveAnnotation(StateAnnotationName);

private class NonNullabilityConventionState
{
public Type NullableAttrType;
public Type NullableContextAttrType;
public FieldInfo NullableFlagsFieldInfo;
public FieldInfo NullableContextFlagFieldInfo;
public Dictionary<Type, bool> TypeNonNullabilityContextCache { get; } = new Dictionary<Type, bool>();
public Dictionary<Module, bool> ModuleNonNullabilityContextCache { get; } = new Dictionary<Module, bool>();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ public virtual void ProcessNavigationAdded(
Check.NotNull(relationshipBuilder, nameof(relationshipBuilder));
Check.NotNull(navigation, nameof(navigation));

if (!IsNonNullable(navigation)
|| navigation.IsCollection())
var modelBuilder = relationshipBuilder.ModelBuilder;

if (!IsNonNullable(modelBuilder, navigation) || navigation.IsCollection())
{
return;
}
Expand All @@ -51,7 +52,7 @@ public virtual void ProcessNavigationAdded(
var inverse = navigation.FindInverse();
if (inverse != null)
{
if (IsNonNullable(inverse))
if (IsNonNullable(modelBuilder, inverse))
{
Dependencies.Logger.NonNullableReferenceOnBothNavigations(navigation, inverse);
return;
Expand Down Expand Up @@ -82,9 +83,9 @@ public virtual void ProcessNavigationAdded(
context.StopProcessingIfChanged(relationshipBuilder.Metadata.DependentToPrincipal);
}

private bool IsNonNullable(IConventionNavigation navigation)
private bool IsNonNullable(IConventionModelBuilder modelBuilder, IConventionNavigation navigation)
=> navigation.DeclaringEntityType.HasClrType()
&& navigation.DeclaringEntityType.GetRuntimeProperties().Find(navigation.Name) is PropertyInfo propertyInfo
&& IsNonNullable(propertyInfo);
&& IsNonNullable(modelBuilder, propertyInfo);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private void Process(IConventionPropertyBuilder propertyBuilder)
// If the model is spread across multiple assemblies, it may contain different NullableAttribute types as
// the compiler synthesizes them for each assembly.
if (propertyBuilder.Metadata.GetIdentifyingMemberInfo() is MemberInfo memberInfo
&& IsNonNullable(memberInfo))
&& IsNonNullable(propertyBuilder.ModelBuilder, memberInfo))
{
propertyBuilder.IsRequired(true);
}
Expand Down