Skip to content

Commit

Permalink
Top k improved search (#25)
Browse files Browse the repository at this point in the history
* Update NativeEmbeddings.swift

* update

update

* test cases and improved topKsearch

* Update Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift

Co-authored-by: Zach Nagengast <[email protected]>

* Update Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift

Co-authored-by: Zach Nagengast <[email protected]>

* getDefault StoragePath is now public

I didnt see a reason why this function has to be private and i need to access the file location from outside.

* dokumentation and access to filePath

* added documentation

added documentation

* performance test case

* Update DistanceTest.swift

* Update formatting and naming

---------

Co-authored-by: BernhardEisvogel <[email protected]>
Co-authored-by: Zach Nagengast <[email protected]>
Co-authored-by: ZachNagengast <[email protected]>
  • Loading branch information
4 people authored Nov 14, 2023
1 parent 61a2e09 commit 6d78d30
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ public protocol DistanceMetricProtocol {
/// Find the nearest neighbors given a query embedding vector and a list of embeddings vectors.
///
/// - Parameters:
/// - queryEmbedding: A `[Float]` array representing the query embedding vector.
/// - itemEmbeddings: A `[[Float]]` array representing the list of embeddings vectors to search within.
/// - resultsCount: An Int representing the number of nearest neighbors to return.
/// - queryEmbedding: A `[Float]` array representing the query embedding vector.
/// - itemEmbeddings: A `[[Float]]` array representing the list of embeddings vectors to search within.
/// - resultsCount: An Int representing the number of nearest neighbors to return.
///
/// - Returns: A `[(Float, Int)]` array, where each tuple contains the similarity score and the index of the corresponding item in `neighborEmbeddings`. The array is sorted by decreasing similarity ranking.
func findNearest(for queryEmbedding: [Float], in neighborEmbeddings: [[Float]], resultsCount: Int) -> [(Float, Int)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,37 +76,49 @@ public struct EuclideanDistance: DistanceMetricProtocol {
/// Helper function to sort scores and return the top K scores with their indices.
///
/// - Parameters:
/// - scores: An array of Float values representing scores.
/// - topK: The number of top scores to return.
/// - scores: An array of Float values representing scores.
/// - topK: The number of top scores to return.
///
/// - Returns: An array of tuples containing the top K scores and their corresponding indices.
public func sortedScores(scores: [Float], topK: Int) -> [(Float, Int)] {
// Combine indices & scores
let indexedScores = scores.enumerated().map { index, score in (score, index) }

// Sort by decreasing score
let sortedIndexedScores = indexedScores.sorted { $0.0 > $1.0 }
func compare(a: (Float, Int), b: (Float, Int)) throws -> Bool {
return a.0 > b.0
}

// Take top k neighbors
let results = Array(sortedIndexedScores.prefix(topK))

return results
do {
return try indexedScores.topK(topK, by: compare)
} catch {
print("There has been an error comparing elements in sortedScores")
return []
}
}

/// Helper function to sort distances and return the top K distances with their indices.
///
/// - Parameters:
/// - distances: An array of Float values representing distances.
/// - topK: The number of top distances to return.
/// - distances: An array of Float values representing distances.
/// - topK: The number of top distances to return.
///
/// - Returns: An array of tuples containing the top K distances and their corresponding indices.
public func sortedDistances(distances: [Float], topK: Int) -> [(Float, Int)] {
// Combine indices & distances
let indexedDistances = distances.enumerated().map { index, score in (score, index) }

// Sort by increasing distance
let sortedIndexedDistances = indexedDistances.sorted { $0.0 < $1.0 }
func compare(a: (Float, Int), b: (Float, Int)) throws -> Bool {
return a.0 < b.0
}

// Take top k neighbors
let results = Array(sortedIndexedDistances.prefix(topK))

return results
do {
return try indexedDistances.topK(topK, by: compare)
} catch {
print("There has been an error comparing elements in sortedDistances")
return []
}
}
60 changes: 60 additions & 0 deletions Sources/SimilaritySearchKit/Core/Embeddings/Metrics/TopK.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//
// TopK.swift
//
//
// Created by Bernhard Eisvogel on 31.10.23.
//

import Foundation

public extension Collection {
/// Helper function to sort distances and return the top K distances with their indices.
///
/// The `by` parameter accepts a function of the following form:
/// ```swift
/// (Element, Element) throws -> Bool
/// ```
///
/// Adapted from [Stackoverflow](https://stackoverflow.com/questions/65746299/how-do-you-find-the-top-3-maximum-values-in-a-swift-dictionary)
///
/// - Parameters:
/// - count: the number of top distances to return.
/// - by: comparison function
///
/// - Returns: ordered array containing the top K distances
///
/// - Note: TopK and the standard swift implementations switch elements with equal value differently
func topK(_ count: Int, by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows -> [Self.Element] {
assert(count >= 0,
"""
Cannot prefix with a negative amount of elements!
""")

guard count > 0 else {
return []
}

let prefixCount = Swift.min(count, self.count)

guard prefixCount < self.count / 10 else {
return try Array(sorted(by: areInIncreasingOrder).prefix(prefixCount))
}

var result = try self.prefix(prefixCount).sorted(by: areInIncreasingOrder)

for e in self.dropFirst(prefixCount) {
if let last = result.last, try areInIncreasingOrder(last, e) {
continue
}
let insertionIndex = try result.partition { try areInIncreasingOrder(e, $0) }
let isLastElement = insertionIndex == result.endIndex
result.removeLast()
if isLastElement {
result.append(e)
} else {
result.insert(e, at: insertionIndex)
}
}
return result
}
}
28 changes: 19 additions & 9 deletions Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,24 @@ extension SimilarityIndex {
}

public func loadIndex(fromDirectory path: URL? = nil, name: String? = nil) throws -> [IndexItem]? {
if let indexPath = try getIndexPath(fromDirectory: path, name: name) {
let loadedIndexItems = try vectorStore.loadIndex(from: indexPath)
addItems(loadedIndexItems)
print("Loaded \(indexItems.count) index items from \(indexPath.absoluteString)")
return loadedIndexItems
}

return nil
}

/// This function returns the default location where the data from the loadIndex/saveIndex functions gets stored
/// gets stored.
/// - Parameters:
/// - fromDirectory: optional directory path where the file postfix is added to
/// - name: optional name
///
/// - Returns: an optional URL
public func getIndexPath(fromDirectory path: URL? = nil, name: String? = nil) throws -> URL? {
let indexName = name ?? self.indexName
let basePath: URL

Expand All @@ -333,15 +351,7 @@ extension SimilarityIndex {
// Default local path
basePath = try getDefaultStoragePath()
}

if let vectorStorePath = vectorStore.listIndexes(at: basePath).first(where: { $0.lastPathComponent.contains(indexName) }) {
let loadedIndexItems = try vectorStore.loadIndex(from: vectorStorePath)
addItems(loadedIndexItems)
print("Loaded \(indexItems.count) index items from \(vectorStorePath.absoluteString)")
return loadedIndexItems
}

return nil
return vectorStore.listIndexes(at: basePath).first(where: { $0.lastPathComponent.contains(indexName) })
}

private func getDefaultStoragePath() throws -> URL {
Expand Down
64 changes: 64 additions & 0 deletions Tests/SimilaritySearchKitTests/DistanceTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//
// DistanceTest.swift
//
//
// Created by Bernhard Eisvogel on 31.10.23.
//

@testable import SimilaritySearchKit
import XCTest

func randomString(_ length: Int) -> String {
let letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789아오우"
return String((0..<length).map { _ in letters.randomElement()! })
}

final class DistanceTest: XCTestCase {
private var randomStringData: [String] = {
var data: [String] = []
for _ in 0..<5000 {
data.append(randomString(Int.random(in: 7...20)))
}
return data
}()

private var k: Int = 10

func testExampleInt() throws {
let data = Array(0...10000).shuffled()

func sort(a: Int, b: Int) throws -> Bool {
return a<b
}

let topKcorrect = try Array(data.sorted(by: sort).prefix(k))
let topKfast = try data.topK(k, by: sort)
XCTAssertEqual(topKcorrect, topKfast)
}

func sortString(a: String, b: String) throws -> Bool {
return a.hashValue<b.hashValue
}

func testExampleStrSlow() {
// Measures the speed of the old algorithm
measure {
do {
_ = try Array(randomStringData.sorted(by: sortString).prefix(k))
} catch {
print("Error sorting with the old algorithm")
}
}
}

func testExampleStrFast() {
// Measures the speed of the new algorithm
measure {
do {
_ = try randomStringData.topK(k, by: sortString)
} catch {
print("Error sorting with the new algorithm")
}
}
}
}

0 comments on commit 6d78d30

Please sign in to comment.