Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure proper receiver value is used for a constrained call invocation #65642

Merged
merged 3 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ public static RefKind GetRefKind(this BoundExpression node)
case BoundKind.PropertyAccess:
return ((BoundPropertyAccess)node).PropertySymbol.RefKind;

case BoundKind.IndexerAccess:
return ((BoundIndexerAccess)node).Indexer.RefKind;

case BoundKind.ImplicitIndexerAccess:
return ((BoundImplicitIndexerAccess)node).IndexerOrSliceAccess.GetRefKind();

case BoundKind.ObjectInitializerMember:
var member = (BoundObjectInitializerMember)node;
if (member.HasErrors)
Expand Down
4 changes: 2 additions & 2 deletions src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1716,10 +1716,10 @@
<Field Name="Id" Type="int"/>
</Node>

<!-- This node represents a complex receiver for a conditional access.
<!-- This node represents a complex receiver for a call, or a conditional access.
At runtime, when its type is a value type, ValueTypeReceiver should be used as a receiver.
Otherwise, ReferenceTypeReceiver should be used.
This kind of receiver is created only by Async rewriter.
This kind of receiver is created only by SpillSequenceSpiller rewriter.
-->
<Node Name="BoundComplexConditionalReceiver" Base="BoundExpression">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
Expand Down
146 changes: 145 additions & 1 deletion src/Compilers/CSharp/Portable/CodeGen/EmitExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,7 @@ private void EmitInstanceCallExpression(BoundCall call, UseKind useKind)
var receiver = call.ReceiverOpt;
var arguments = call.Arguments;
LocalDefinition tempOpt = null;
LocalDefinition cloneTemp = null;

Debug.Assert(!method.IsStatic && method.RequiresInstanceReceiver);

Expand Down Expand Up @@ -1655,6 +1656,51 @@ private void EmitInstanceCallExpression(BoundCall call, UseKind useKind)
CallKind.ConstrainedCallVirt;

tempOpt = EmitReceiverRef(receiver, callKind == CallKind.ConstrainedCallVirt ? AddressKind.Constrained : AddressKind.Writeable);

if (callKind == CallKind.ConstrainedCallVirt && tempOpt is null && !receiverType.IsValueType &&
!ReceiverIsKnownToReferToTempIfReferenceType(call.ReceiverOpt) &&
!IsSafeToDereferenceReceiverRefAfterEvaluatingArguments(call.Arguments))
{
// A case where T is actually a class must be handled specially.
// Taking a reference to a class instance is fragile because the value behind the
// reference might change while arguments are evaluated. However, the call should be
// performed on the instance that is behind reference at the time we push the
// reference to the stack. So, for a class we need to emit a reference to a temporary
// location, rather than to the original location

// Struct values are never nulls.
// We will emit a check for such case, but the check is really a JIT-time
// constant since JIT will know if T is a struct or not.

// if ((object)default(T) == null)
// {
// temp = receiverRef
// receiverRef = ref temp
// }

object whenNotNullLabel = null;

if (!receiverType.IsReferenceType)
{
// if ((object)default(T) == null)
EmitDefaultValue(receiverType, true, receiver.Syntax);
EmitBox(receiverType, receiver.Syntax);
whenNotNullLabel = new object();
_builder.EmitBranch(ILOpCode.Brtrue, whenNotNullLabel);
}

// temp = receiverRef
// receiverRef = ref temp
EmitLoadIndirect(receiverType, receiver.Syntax);
cloneTemp = AllocateTemp(receiverType, receiver.Syntax);
_builder.EmitLocalStore(cloneTemp);
_builder.EmitLocalAddress(cloneTemp);

if (whenNotNullLabel is not null)
{
_builder.MarkLabel(whenNotNullLabel);
}
}
}

// When emitting a callvirt to a virtual method we always emit the method info of the
Expand Down Expand Up @@ -1730,6 +1776,104 @@ private void EmitInstanceCallExpression(BoundCall call, UseKind useKind)
EmitCallCleanup(call.Syntax, useKind, method);

FreeOptTemp(tempOpt);
FreeOptTemp(cloneTemp);
}

internal static bool IsPossibleReferenceTypeReceiverOfConstrainedCall(BoundExpression receiver)
{
var receiverType = receiver.Type;

if (receiverType.IsVerifierReference() || receiverType.IsVerifierValue())
{
return false;
}

return !receiverType.IsValueType;
}

internal static bool ReceiverIsKnownToReferToTempIfReferenceType(BoundExpression receiver)
{
while (receiver is BoundSequence sequence)
{
receiver = sequence.Value;
}

if (receiver is
BoundLocal { LocalSymbol.IsKnownToReferToTempIfReferenceType: true } or
BoundComplexConditionalReceiver)
{
return true;
}

return false;
}

internal static bool IsSafeToDereferenceReceiverRefAfterEvaluatingArguments(ImmutableArray<BoundExpression> arguments)
{
return arguments.All(isSafeToDereferenceReceiverRefAfterEvaluatingArgument);

static bool isSafeToDereferenceReceiverRefAfterEvaluatingArgument(BoundExpression expression)
{
var current = expression;
while (true)
{
if (current.ConstantValue != null)
{
return true;
}

switch (current.Kind)
{
default:
return false;
case BoundKind.TypeExpression:
case BoundKind.Parameter:
case BoundKind.Local:
case BoundKind.ThisReference:
return true;
case BoundKind.FieldAccess:
{
var field = (BoundFieldAccess)current;
current = field.ReceiverOpt;
if (current is null)
{
return true;
}

break;
}
case BoundKind.PassByCopy:
current = ((BoundPassByCopy)current).Expression;
break;
case BoundKind.BinaryOperator:
{
BoundBinaryOperator b = (BoundBinaryOperator)current;
Debug.Assert(!b.OperatorKind.IsUserDefined());

if (b.OperatorKind.IsUserDefined() || !isSafeToDereferenceReceiverRefAfterEvaluatingArgument(b.Right))
{
return false;
}

current = b.Left;
break;
}
case BoundKind.Conversion:
{
BoundConversion conv = (BoundConversion)current;
Debug.Assert(!conv.ConversionKind.IsUserDefinedConversion());

if (conv.ConversionKind.IsUserDefinedConversion())
{
return false;
}

current = conv.Operand;
break;
}
}
}
}
}

private bool IsReadOnlyCall(MethodSymbol method, NamedTypeSymbol methodContainingType)
Expand Down Expand Up @@ -1759,7 +1903,7 @@ private bool IsReadOnlyCall(MethodSymbol method, NamedTypeSymbol methodContainin
// returns true when receiver is already a ref.
// in such cases calling through a ref could be preferred over
// calling through indirectly loaded value.
private bool IsRef(BoundExpression receiver)
internal static bool IsRef(BoundExpression receiver)
{
switch (receiver.Kind)
{
Expand Down
30 changes: 27 additions & 3 deletions src/Compilers/CSharp/Portable/CodeGen/Optimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ public override BoundNode VisitCall(BoundCall node)
// assume we will need an address (that will prevent scheduling of receiver).
if (method.RequiresInstanceReceiver)
{
receiver = VisitCallReceiver(receiver);
receiver = VisitCallOrConditionalAccessReceiver(receiver, node);
}
else
{
Expand All @@ -1132,9 +1132,28 @@ public override BoundNode VisitCall(BoundCall node)
return node.Update(receiver, method, rewrittenArguments);
}

private BoundExpression VisitCallReceiver(BoundExpression receiver)
private BoundExpression VisitCallOrConditionalAccessReceiver(BoundExpression receiver, BoundCall callOpt)
{
var receiverType = receiver.Type;

if (callOpt is { } call &&
CodeGenerator.IsRef(receiver) &&
CodeGenerator.IsPossibleReferenceTypeReceiverOfConstrainedCall(receiver) &&
!CodeGenerator.IsSafeToDereferenceReceiverRefAfterEvaluatingArguments(call.Arguments))
{
var unwrappedSequence = receiver;

while (unwrappedSequence is BoundSequence sequence)
{
unwrappedSequence = sequence.Value;
}

if (unwrappedSequence is BoundLocal { LocalSymbol: { RefKind: not RefKind.None } localSymbol })
{
ShouldNotSchedule(localSymbol); // Otherwise CodeGenerator is unable to apply proper fixups
}
}

ExprContext context;

if (receiverType.IsReferenceType)
Expand Down Expand Up @@ -1494,7 +1513,7 @@ public override BoundNode VisitNullCoalescingOperator(BoundNullCoalescingOperato
public override BoundNode VisitLoweredConditionalAccess(BoundLoweredConditionalAccess node)
{
var origStack = StackDepth();
BoundExpression receiver = VisitCallReceiver(node.Receiver);
BoundExpression receiver = VisitCallOrConditionalAccessReceiver(node.Receiver, callOpt: null);

var cookie = GetStackStateCookie(); // implicit branch here

Expand Down Expand Up @@ -2210,6 +2229,11 @@ internal override bool IsPinned
get { return false; }
}

internal override bool IsKnownToReferToTempIfReferenceType
{
get { return false; }
}

public override Symbol ContainingSymbol
{
get { throw new NotImplementedException(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public override bool Equals(Symbol obj, TypeCompareKind compareKind)
internal override bool IsCompilerGenerated => true;
internal override bool IsImportedFromMetadata => false;
internal override bool IsPinned => false;
internal override bool IsKnownToReferToTempIfReferenceType => false;
public override RefKind RefKind => RefKind.None;
internal override SynthesizedLocalKind SynthesizedKind => throw ExceptionUtilities.Unreachable();
internal override ConstantValue GetConstantValue(SyntaxNode node, LocalSymbol inProgress, BindingDiagnosticBag diagnostics = null) => null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ private BoundExpression MakePropertyAssignment(
ArrayBuilder<LocalSymbol>? argTempsBuilder = null;
arguments = VisitArgumentsAndCaptureReceiverIfNeeded(
ref rewrittenReceiver,
captureReceiverForMultipleInvocations: false,
captureReceiverMode: ReceiverCaptureMode.Default,
arguments,
property,
argsToParamsOpt,
Expand Down
Loading