Skip to content

Commit

Permalink
feat: add score threshold to MVI search (#516)
Browse files Browse the repository at this point in the history
Add score threshold to the MVI search method.

Update the protos and change 'distance' to 'score' in internal search
response.
  • Loading branch information
nand4011 authored Nov 11, 2023
1 parent 1c8c61f commit fc41d11
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 33 deletions.
58 changes: 31 additions & 27 deletions src/Momento.Sdk/IPreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,32 +156,36 @@ public Task<UpsertItemBatchResponse> UpsertItemBatchAsync(string indexName,
///</returns>
public Task<DeleteItemBatchResponse> DeleteItemBatchAsync(string indexName, IEnumerable<string> ids);

/// <summary>
/// Searches for the most similar vectors to the query vector in the index.
/// Ranks the vectors according to the similarity metric specified when the
/// index was created.
/// </summary>
/// <param name="indexName">The name of the vector index to search in.</param>
/// <param name="queryVector">The vector to search for.</param>
/// <param name="topK">The number of results to return. Defaults to 10.</param>
/// <param name="metadataFields">A list of metadata fields to return with each result.</param>
/// <returns>
/// Task representing the result of the upsert operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>SearchResponse.Success</description></item>
/// <item><description>SearchResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is SearchResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
///</returns>
/// <summary>
/// Searches for the most similar vectors to the query vector in the index.
/// Ranks the vectors according to the similarity metric specified when the
/// index was created.
/// </summary>
/// <param name="indexName">The name of the vector index to search in.</param>
/// <param name="queryVector">The vector to search for.</param>
/// <param name="topK">The number of results to return. Defaults to 10.</param>
/// <param name="metadataFields">A list of metadata fields to return with each result.</param>
/// <param name="scoreThreshold">A score threshold to filter results by. For cosine
/// similarity and inner product, scores lower than the threshold are excluded. For
/// euclidean similarity, scores higher than the threshold are excluded. The threshold
/// is exclusive. Defaults to None, ie no threshold.</param>
/// <returns>
/// Task representing the result of the upsert operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>SearchResponse.Success</description></item>
/// <item><description>SearchResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is SearchResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
/// </returns>
public Task<SearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, int topK = 10,
MetadataFields? metadataFields = null);
MetadataFields? metadataFields = null, float? scoreThreshold = null);
}
15 changes: 12 additions & 3 deletions src/Momento.Sdk/Internal/VectorIndexDataClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public async Task<DeleteItemBatchResponse> DeleteItemBatchAsync(string indexName
}

public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, int topK,
MetadataFields? metadataFields)
MetadataFields? metadataFields, float? scoreThreshold)
{
try
{
Expand All @@ -91,9 +91,18 @@ public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<floa
IndexName = indexName,
QueryVector = new _Vector { Elements = { queryVector } },
TopK = validatedTopK,
MetadataFields = metadataRequest
MetadataFields = metadataRequest,
};

if (scoreThreshold != null)
{
request.ScoreThreshold = scoreThreshold.Value;
}
else
{
request.NoScoreThreshold = new _NoScoreThreshold();
}

var response =
await grpcManager.Client.SearchAsync(request, new CallOptions(deadline: CalculateDeadline()));
var searchHits = response.Hits.Select(Convert).ToList();
Expand Down Expand Up @@ -167,7 +176,7 @@ private static MetadataValue Convert(_Metadata metadata)

private static SearchHit Convert(_SearchHit hit)
{
return new SearchHit(hit.Id, hit.Distance, Convert(hit.Metadata));
return new SearchHit(hit.Id, hit.Score, Convert(hit.Metadata));
}

private static void CheckValidIndexName(string indexName)
Expand Down
2 changes: 1 addition & 1 deletion src/Momento.Sdk/Momento.Sdk.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
<ItemGroup>
<PackageReference Include="Grpc.Net.Client" Version="2.49.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Momento.Protos" Version="0.91.1" />
<PackageReference Include="Momento.Protos" Version="0.94.1" />
<PackageReference Include="JWT" Version="9.0.3" />
<PackageReference Include="System.Threading.Channels" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0" />
Expand Down
4 changes: 2 additions & 2 deletions src/Momento.Sdk/PreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ public async Task<DeleteItemBatchResponse> DeleteItemBatchAsync(string indexName

/// <inheritdoc />
public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector,
int topK = 10, MetadataFields? metadataFields = null)
int topK = 10, MetadataFields? metadataFields = null, float? searchThreshold = null)
{
return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields);
return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields, searchThreshold);
}

/// <inheritdoc />
Expand Down
85 changes: 85 additions & 0 deletions tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -317,4 +317,89 @@ public async Task UpsertAndSearch_WithDiverseMetadata()
await vectorIndexClient.DeleteIndexAsync(indexName);
}
}

public static IEnumerable<object[]> SearchThresholdTestCases =>
new List<object[]>
{
// similarity metric, scores, thresholds
new object[]
{
SimilarityMetric.CosineSimilarity,
new List<float> { 1.0f, 0.0f, -1.0f },
new List<float> { 0.5f, -1.01f, 1.0f }
},
new object[]
{
SimilarityMetric.InnerProduct,
new List<float> { 4.0f, 0.0f, -4.0f },
new List<float> { 0.0f, -4.01f, 4.0f }
},
new object[]
{
SimilarityMetric.EuclideanSimilarity,
new List<float> { 2.0f, 10.0f, 18.0f },
new List<float> { 3.0f, 20.0f, -0.01f }
}
};

[Theory]
[MemberData(nameof(SearchThresholdTestCases))]
public async Task Search_PruneBasedOnThreshold(SimilarityMetric similarityMetric, List<float> scores,
List<float> thresholds)
{
var indexName = $"index-{Utils.NewGuidString()}";

var createResponse = await vectorIndexClient.CreateIndexAsync(indexName, 2, similarityMetric);
Assert.True(createResponse is CreateIndexResponse.Success, $"Unexpected response: {createResponse}");

try
{
var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List<Item>
{
new("test_item_1", new List<float> { 1.0f, 1.0f }),
new("test_item_2", new List<float> { -1.0f, 1.0f }),
new("test_item_3", new List<float> { -1.0f, -1.0f })
});
Assert.True(upsertResponse is UpsertItemBatchResponse.Success,
$"Unexpected response: {upsertResponse}");

await Task.Delay(2_000);

var queryVector = new List<float> { 2.0f, 2.0f };
var searchHits = new List<SearchHit>
{
new("test_item_1", scores[0]),
new("test_item_2", scores[1]),
new("test_item_3", scores[2])
};

// Test threshold to get only the top result
var searchResponse =
await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[0]);
Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}");
var successResponse = (SearchResponse.Success)searchResponse;
Assert.Equal(new List<SearchHit>
{
searchHits[0]
}, successResponse.Hits);

// Test threshold to get all results
searchResponse =
await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[1]);
Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}");
successResponse = (SearchResponse.Success)searchResponse;
Assert.Equal(searchHits, successResponse.Hits);

// Test threshold to get no results
searchResponse =
await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[2]);
Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}");
successResponse = (SearchResponse.Success)searchResponse;
Assert.Empty(successResponse.Hits);
}
finally
{
await vectorIndexClient.DeleteIndexAsync(indexName);
}
}
}

0 comments on commit fc41d11

Please sign in to comment.