Skip to content

Commit

Permalink
Limit amount of tokens used to calculate LCS distance (#68151)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmat authored May 11, 2023
1 parent 63019e9 commit b404fbe
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

using System;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis.Differencing;
using Microsoft.CodeAnalysis.Test.Utilities;
using Roslyn.Test.Utilities;
Expand Down Expand Up @@ -84,7 +85,7 @@ public void GetSequenceEdits4()
}

[Fact]
public void ComputeDistance1()
public void ComputeDistance_Nodes()
{
var distance = SyntaxComparer.ComputeDistance(
new[] { MakeLiteral(0), MakeLiteral(1), MakeLiteral(2) },
Expand All @@ -94,17 +95,7 @@ public void ComputeDistance1()
}

[Fact]
public void ComputeDistance2()
{
var distance = SyntaxComparer.ComputeDistance(
ImmutableArray.Create(MakeLiteral(0), MakeLiteral(1), MakeLiteral(2)),
ImmutableArray.Create(MakeLiteral(1), MakeLiteral(3)));

Assert.Equal(0.67, Math.Round(distance, 2));
}

[Fact]
public void ComputeDistance3()
public void ComputeDistance_Tokens()
{
var distance = SyntaxComparer.ComputeDistance(
new[] { SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.StaticKeyword), SyntaxFactory.Token(SyntaxKind.AsyncKeyword) },
Expand All @@ -113,16 +104,6 @@ public void ComputeDistance3()
Assert.Equal(0.33, Math.Round(distance, 2));
}

[Fact]
public void ComputeDistance4()
{
var distance = SyntaxComparer.ComputeDistance(
ImmutableArray.Create(SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.StaticKeyword), SyntaxFactory.Token(SyntaxKind.AsyncKeyword)),
ImmutableArray.Create(SyntaxFactory.Token(SyntaxKind.StaticKeyword), SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.AsyncKeyword)));

Assert.Equal(0.33, Math.Round(distance, 2));
}

[Fact]
public void ComputeDistance_Token()
{
Expand All @@ -141,18 +122,6 @@ public void ComputeDistance_Node()
public void ComputeDistance_Null()
{
var distance = SyntaxComparer.ComputeDistance(
default,
ImmutableArray.Create(SyntaxFactory.Token(SyntaxKind.StaticKeyword)));

Assert.Equal(1, Math.Round(distance, 2));

distance = SyntaxComparer.ComputeDistance(
default,
ImmutableArray.Create(MakeLiteral(0)));

Assert.Equal(1, Math.Round(distance, 2));

distance = SyntaxComparer.ComputeDistance(
null,
Array.Empty<SyntaxNode>());

Expand All @@ -176,5 +145,20 @@ public void ComputeDistance_Null()

Assert.Equal(0, Math.Round(distance, 2));
}

// Verifies that the distance between two very long token sequences is computed
// only over a bounded prefix of each sequence.
[Fact]
public void ComputeDistance_LongSequences()
{
var t1 = SyntaxFactory.Token(SyntaxKind.PublicKeyword);
var t2 = SyntaxFactory.Token(SyntaxKind.PrivateKeyword);
var t3 = SyntaxFactory.Token(SyntaxKind.ProtectedKeyword);

// Both sequences contain 10000 tokens and share the first 2000 (t1);
// they differ only in the remaining 8000 tokens (t2 vs. t3).
var distance = SyntaxComparer.ComputeDistance(
Enumerable.Range(0, 10000).Select(i => i < 2000 ? t1 : t2),
Enumerable.Range(0, 10000).Select(i => i < 2000 ? t1 : t3));

// Long sequences are indistinguishable when their common prefix is at least as long as
// the limit on the number of items used for the distance calculation
// (LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation): only that many
// leading items of each sequence are compared, so the differing tails are never examined.
Assert.Equal(0, distance);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue.UnitTests
End Sub

<Fact>
Public Sub ComputeDistance1()
Public Sub ComputeDistance_Nodes()
Dim distance = SyntaxComparer.ComputeDistance(
{MakeLiteral(0), MakeLiteral(1), MakeLiteral(2)},
{MakeLiteral(1), MakeLiteral(3)})
Expand All @@ -84,16 +84,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue.UnitTests
End Sub

<Fact>
Public Sub ComputeDistance2()
Dim distance = SyntaxComparer.ComputeDistance(
ImmutableArray.Create(MakeLiteral(0), MakeLiteral(1), MakeLiteral(2)),
ImmutableArray.Create(MakeLiteral(1), MakeLiteral(3)))

Assert.Equal(0.67, Math.Round(distance, 2))
End Sub

<Fact>
Public Sub ComputeDistance3()
Public Sub ComputeDistance_Tokens()
Dim distance = SyntaxComparer.ComputeDistance(
{SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.StaticKeyword), SyntaxFactory.Token(SyntaxKind.AsyncKeyword)},
{SyntaxFactory.Token(SyntaxKind.StaticKeyword), SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.AsyncKeyword)})
Expand Down Expand Up @@ -160,5 +151,19 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue.UnitTests

Assert.Equal(0.0, Math.Round(distance, 2))
End Sub

' Verifies that the distance between two very long token sequences is computed
' only over a bounded prefix of each sequence.
<Fact>
Public Sub ComputeDistance_LongSequences()
Dim t1 = SyntaxFactory.Token(SyntaxKind.PublicKeyword)
Dim t2 = SyntaxFactory.Token(SyntaxKind.PrivateKeyword)
Dim t3 = SyntaxFactory.Token(SyntaxKind.ProtectedKeyword)

' Both sequences contain 10000 tokens and share the first 2000 (t1);
' they differ only in the remaining 8000 tokens (t2 vs. t3).
Dim distance = SyntaxComparer.ComputeDistance(
Enumerable.Range(0, 10000).Select(Function(i) If(i < 2000, t1, t2)),
Enumerable.Range(0, 10000).Select(Function(i) If(i < 2000, t1, t3)))

' Long sequences are indistinguishable when their common prefix is at least as long as
' the limit on the number of items used for the distance calculation
' (LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation): only that many
' leading items of each sequence are compared, so the differing tails are never examined.
Assert.Equal(0.0, distance)
End Sub
End Class
End Namespace
30 changes: 9 additions & 21 deletions src/Features/CSharp/Portable/EditAndContinue/SyntaxComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Differencing;
Expand Down Expand Up @@ -1593,25 +1594,21 @@ public static double ComputeDistance(SyntaxNode? oldNode, SyntaxNode? newNode)
/// Distance is a number within [0, 1], the smaller the more similar the tokens are.
/// </remarks>
public static double ComputeDistance(SyntaxToken oldToken, SyntaxToken newToken)
=> LongestCommonSubstring.ComputeDistance(oldToken.Text, newToken.Text);
=> LongestCommonSubstring.ComputePrefixDistance(
oldToken.Text, Math.Min(oldToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation),
newToken.Text, Math.Min(newToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation));

/// <summary>
/// Calculates the distance between two sequences of syntax tokens, disregarding trivia.
/// </summary>
/// <remarks>
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(IEnumerable<SyntaxToken>? oldTokens, IEnumerable<SyntaxToken>? newTokens)
=> LcsTokens.Instance.ComputeDistance(oldTokens.AsImmutableOrEmpty(), newTokens.AsImmutableOrEmpty());
private static ImmutableArray<T> CreateArrayForDistanceCalculation<T>(IEnumerable<T>? enumerable)
=> enumerable is null ? ImmutableArray<T>.Empty : enumerable.Take(LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation).ToImmutableArray();

/// <summary>
/// Calculates the distance between two sequences of syntax tokens, disregarding trivia.
/// </summary>
/// <remarks>
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(ImmutableArray<SyntaxToken> oldTokens, ImmutableArray<SyntaxToken> newTokens)
=> LcsTokens.Instance.ComputeDistance(oldTokens.NullToEmpty(), newTokens.NullToEmpty());
public static double ComputeDistance(IEnumerable<SyntaxToken>? oldTokens, IEnumerable<SyntaxToken>? newTokens)
=> LcsTokens.Instance.ComputeDistance(CreateArrayForDistanceCalculation(oldTokens), CreateArrayForDistanceCalculation(newTokens));

/// <summary>
/// Calculates the distance between two sequences of syntax nodes, disregarding trivia.
Expand All @@ -1620,16 +1617,7 @@ public static double ComputeDistance(ImmutableArray<SyntaxToken> oldTokens, Immu
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(IEnumerable<SyntaxNode>? oldNodes, IEnumerable<SyntaxNode>? newNodes)
=> LcsNodes.Instance.ComputeDistance(oldNodes.AsImmutableOrEmpty(), newNodes.AsImmutableOrEmpty());

/// <summary>
/// Calculates the distance between two sequences of syntax tokens, disregarding trivia.
/// </summary>
/// <remarks>
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(ImmutableArray<SyntaxNode> oldNodes, ImmutableArray<SyntaxNode> newNodes)
=> LcsNodes.Instance.ComputeDistance(oldNodes.NullToEmpty(), newNodes.NullToEmpty());
=> LcsNodes.Instance.ComputeDistance(CreateArrayForDistanceCalculation(oldNodes), CreateArrayForDistanceCalculation(newNodes));

/// <summary>
/// Calculates the edits that transform one sequence of syntax nodes to another, disregarding trivia.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ Imports Microsoft.CodeAnalysis.VisualBasic.Syntax
Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue

Friend NotInheritable Class SyntaxComparer

Inherits AbstractSyntaxComparer

Friend Shared ReadOnly TopLevel As SyntaxComparer = New SyntaxComparer(Nothing, Nothing, Nothing, Nothing, compareStatementSyntax:=False)
Expand Down Expand Up @@ -1389,17 +1388,13 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue
''' Distance is a number within [0, 1], the smaller the more similar the tokens are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldToken As SyntaxToken, newToken As SyntaxToken) As Double
Return LongestCommonSubstring.ComputeDistance(oldToken.ValueText, newToken.ValueText)
Return LongestCommonSubstring.ComputePrefixDistance(
oldToken.Text, Math.Min(oldToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation),
newToken.Text, Math.Min(newToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation))
End Function

''' <summary>
''' Calculates the distance between two sequences of syntax tokens, disregarding trivia.
''' </summary>
''' <remarks>
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As IEnumerable(Of SyntaxToken), newTokens As IEnumerable(Of SyntaxToken)) As Double
Return ComputeDistance(oldTokens.AsImmutableOrNull(), newTokens.AsImmutableOrNull())
Private Shared Function CreateArrayForDistanceCalculation(Of T)(enumerable As IEnumerable(Of T)) As ImmutableArray(Of T)
Return If(enumerable Is Nothing, ImmutableArray(Of T).Empty, enumerable.Take(LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation).ToImmutableArray())
End Function

''' <summary>
Expand All @@ -1408,8 +1403,8 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue
''' <remarks>
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As ImmutableArray(Of SyntaxToken), newTokens As ImmutableArray(Of SyntaxToken)) As Double
Return LcsTokens.Instance.ComputeDistance(oldTokens.NullToEmpty(), newTokens.NullToEmpty())
Public Overloads Shared Function ComputeDistance(oldTokens As IEnumerable(Of SyntaxToken), newTokens As IEnumerable(Of SyntaxToken)) As Double
Return LcsTokens.Instance.ComputeDistance(CreateArrayForDistanceCalculation(oldTokens), CreateArrayForDistanceCalculation(newTokens))
End Function

''' <summary>
Expand All @@ -1419,17 +1414,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As IEnumerable(Of SyntaxNode), newTokens As IEnumerable(Of SyntaxNode)) As Double
Return ComputeDistance(oldTokens.AsImmutableOrNull(), newTokens.AsImmutableOrNull())
End Function

''' <summary>
''' Calculates the distance between two sequences of syntax nodes, disregarding trivia.
''' </summary>
''' <remarks>
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As ImmutableArray(Of SyntaxNode), newTokens As ImmutableArray(Of SyntaxNode)) As Double
Return LcsNodes.Instance.ComputeDistance(oldTokens.NullToEmpty(), newTokens.NullToEmpty())
Return LcsNodes.Instance.ComputeDistance(CreateArrayForDistanceCalculation(oldTokens), CreateArrayForDistanceCalculation(newTokens))
End Function

''' <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@
using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.Differencing
{
internal abstract class LongestCommonSubsequence
{
/// <summary>
/// Limit the number of tokens used to compute distance between sequences of tokens so that
/// we always use the pooled buffers. The combined length of the two sequences being compared
/// must be less than <see cref="VBuffer.PooledSegmentMaxDepthThreshold"/>.
/// </summary>
public const int MaxSequenceLengthForDistanceCalculation = VBuffer.PooledSegmentMaxDepthThreshold / 2;

// Define the pool in a non-generic base class to allow sharing among instantiations.
private static readonly ObjectPool<VBuffer> s_pool = new(() => new VBuffer());

Expand All @@ -40,17 +46,31 @@ protected sealed class VBuffer
/// For 150 it'd be 91KB, which would be allocated on LOH.
/// The buffers grow by factor of <see cref="GrowFactor"/>, so the next buffer will be allocated on LOH.
/// </summary>
private const int FirstBufferMaxDepth = 100;
public const int FirstSegmentMaxDepth = 100;

// 3 + Sum { d = 1..maxDepth : 2*d+1 } = (maxDepth + 1)^2 + 2
private const int FirstBufferLength = (FirstBufferMaxDepth + 1) * (FirstBufferMaxDepth + 1) + 2;

internal const int GrowFactor = 2;
private const int FirstSegmentLength = (FirstSegmentMaxDepth + 1) * (FirstSegmentMaxDepth + 1) + 2;

// Segment Segment Total buffer
// MaxDepth length length
// ---------------------------------------
// 100 10,204 10,204
// 150 12,600 22,804
// 225 28,275 51,079
// 338 63,845 114,924
// 507 143,143 258,067
// 761 322,580 580,647
// 1142 725,805 1,306,452
// 1713 1,631,347 2,937,799 <-- last pooled segment
// 2570 3,672,245 6,610,044
// 3855 8,258,695 14,868,739
internal const double GrowFactor = 1.5;

/// <summary>
/// Do not pool segments that are too large.
/// Do not expand pooled buffers to more than ~12 MB total size (sum of all linked segment sizes).
/// This threshold is achieved when <see cref="MaxDepth"/> is greater than <see cref="PooledSegmentMaxDepthThreshold"/> = sqrt(size_limit / sizeof(int)).
/// </summary>
internal const int MaxPooledBufferSize = 1024 * 1024;
internal const int PooledSegmentMaxDepthThreshold = 1800;

public VBuffer Previous { get; private set; }
public VBuffer Next { get; private set; }
Expand All @@ -62,22 +82,22 @@ protected sealed class VBuffer

public VBuffer()
{
_array = new int[FirstBufferLength];
MaxDepth = FirstBufferMaxDepth;
_array = new int[FirstSegmentLength];
MaxDepth = FirstSegmentMaxDepth;
}

public VBuffer(VBuffer previous)
{
Debug.Assert(previous != null);

var minDepth = previous.MaxDepth + 1;
var maxDepth = previous.MaxDepth * GrowFactor;
var maxDepth = (int)(previous.MaxDepth * GrowFactor);

Debug.Assert(minDepth > 0);
Debug.Assert(minDepth <= maxDepth);

Previous = previous;
_array = new int[GetNextBufferLength(minDepth - 1, maxDepth)];
_array = new int[GetNextSegmentLength(minDepth - 1, maxDepth)];
MinDepth = minDepth;
MaxDepth = maxDepth;

Expand All @@ -95,7 +115,7 @@ public VArray GetVArray(int depth)
}

public bool IsTooLargeToPool
=> _array.Length > MaxPooledBufferSize;
=> MaxDepth > PooledSegmentMaxDepthThreshold;

private static int GetVArrayLength(int depth)
=> 2 * Math.Max(depth, 1) + 1;
Expand All @@ -105,7 +125,7 @@ private static int GetVArrayStart(int depth)
=> (depth == 0) ? 0 : depth * depth + 2;

// Sum { d = precedingBufferMaxDepth..maxDepth : 2*d+1 } = (maxDepth + 1)^2 - precedingBufferMaxDepth^2
private static int GetNextBufferLength(int precedingBufferMaxDepth, int maxDepth)
private static int GetNextSegmentLength(int precedingBufferMaxDepth, int maxDepth)
=> (maxDepth + 1) * (maxDepth + 1) - precedingBufferMaxDepth * precedingBufferMaxDepth;

public void Unlink()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ private LongestCommonSubstring()
protected override bool ItemsEqual(string oldSequence, int oldIndex, string newSequence, int newIndex)
=> oldSequence[oldIndex] == newSequence[newIndex];

public static double ComputeDistance(string oldValue, string newValue)
=> s_instance.ComputeDistance(oldValue, oldValue.Length, newValue, newValue.Length);
public static double ComputePrefixDistance(string oldValue, int oldLength, string newValue, int newLength)
=> s_instance.ComputeDistance(oldValue, oldLength, newValue, newLength);

public static IEnumerable<SequenceEdit> GetEdits(string oldValue, string newValue)
=> s_instance.GetEdits(oldValue, oldValue.Length, newValue, newValue.Length);
Expand Down

0 comments on commit b404fbe

Please sign in to comment.