-
-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update NativeEmbeddings.swift * update update * test cases and improved topKsearch * Update Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift Co-authored-by: Zach Nagengast <[email protected]> * Update Sources/SimilaritySearchKit/Core/Embeddings/Models/NativeEmbeddings.swift Co-authored-by: Zach Nagengast <[email protected]> * getDefault StoragePath is now public I didn't see a reason why this function has to be private, and I need to access the file location from outside. * documentation and access to filePath * added documentation added documentation * performance test case * Update DistanceTest.swift * Update formatting and naming --------- Co-authored-by: BernhardEisvogel <[email protected]> Co-authored-by: Zach Nagengast <[email protected]> Co-authored-by: ZachNagengast <[email protected]>
- Loading branch information
1 parent
61a2e09
commit 6d78d30
Showing
5 changed files
with
170 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
Sources/SimilaritySearchKit/Core/Embeddings/Metrics/TopK.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// | ||
// TopK.swift | ||
// | ||
// | ||
// Created by Bernhard Eisvogel on 31.10.23. | ||
// | ||
|
||
import Foundation | ||
|
||
public extension Collection {
    /// Returns the first `count` elements of the collection in the order defined by
    /// `areInIncreasingOrder`, avoiding a full sort when `count` is small.
    ///
    /// The `by` parameter accepts a function of the following form:
    /// ```swift
    /// (Element, Element) throws -> Bool
    /// ```
    ///
    /// Adapted from [Stack Overflow](https://stackoverflow.com/questions/65746299/how-do-you-find-the-top-3-maximum-values-in-a-swift-dictionary)
    ///
    /// - Parameters:
    ///   - count: The number of top elements to return. Values larger than the collection's
    ///     size are clamped; zero (or, in release builds, a negative value) yields an empty array.
    ///   - areInIncreasingOrder: A predicate returning `true` when its first argument should be
    ///     ordered before its second.
    ///
    /// - Returns: An ordered array containing at most `count` elements.
    ///
    /// - Throws: Rethrows any error thrown by `areInIncreasingOrder`.
    ///
    /// - Note: When several elements compare as equal, the selection path keeps the ones that
    ///   appear earlier in the collection and never lets a later equal element displace an
    ///   earlier one. The fallback path delegates tie handling to `sorted(by:)`.
    func topK(_ count: Int, by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows -> [Self.Element] {
        assert(count >= 0, "Cannot prefix with a negative amount of elements!")

        guard count > 0 else {
            return []
        }

        // `count` is O(n) on collections without random access, so read it only once.
        let totalCount = self.count
        let prefixCount = Swift.min(count, totalCount)

        // When K is not small relative to the collection, fall back to a plain full sort.
        guard prefixCount < totalCount / 10 else {
            return try Array(sorted(by: areInIncreasingOrder).prefix(prefixCount))
        }

        // Seed the running top-K with the first K elements, kept sorted at all times.
        var result = try prefix(prefixCount).sorted(by: areInIncreasingOrder)

        for candidate in dropFirst(prefixCount) {
            // Only a candidate that is strictly better than the current worst may enter.
            // Skipping ties here keeps earlier elements in place.
            guard let currentWorst = result.last,
                  try areInIncreasingOrder(candidate, currentWorst) else {
                continue
            }

            // First slot whose element must come after the candidate. Because the candidate
            // precedes `currentWorst`, a match always exists; the fallback is purely defensive.
            let lastIndex = result.index(before: result.endIndex)
            let insertionIndex = try result.firstIndex { try areInIncreasingOrder(candidate, $0) } ?? lastIndex

            result.removeLast()
            result.insert(candidate, at: insertionIndex)
        }
        return result
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// | ||
// DistanceTest.swift | ||
// | ||
// | ||
// Created by Bernhard Eisvogel on 31.10.23. | ||
// | ||
|
||
@testable import SimilaritySearchKit | ||
import XCTest | ||
|
||
/// Builds a string of `length` characters, each drawn at random from a fixed alphabet of
/// ASCII letters, digits, and three Hangul syllables.
func randomString(_ length: Int) -> String {
    let alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789아오우"
    var generated = ""
    for _ in 0..<length {
        // The alphabet is a non-empty literal, so `randomElement()` cannot return nil.
        generated.append(alphabet.randomElement()!)
    }
    return generated
}
|
||
/// Correctness and performance tests comparing `topK(_:by:)` with a full sort followed by `prefix`.
final class DistanceTest: XCTestCase {
    /// 5000 random strings of 7 to 20 characters, built when the test case instance is created.
    private let randomStringData: [String] = (0..<5000).map { _ in
        randomString(Int.random(in: 7...20))
    }

    /// Number of top elements requested in every test.
    private let k: Int = 10

    /// Checks that `topK` returns the same result as sorting everything and taking the prefix.
    func testExampleInt() throws {
        let data = Array(0...10000).shuffled()

        // Declared as throwing so the `rethrows` path of both APIs is exercised.
        func sort(a: Int, b: Int) throws -> Bool {
            return a < b
        }

        let topKcorrect = try Array(data.sorted(by: sort).prefix(k))
        let topKfast = try data.topK(k, by: sort)
        XCTAssertEqual(topKcorrect, topKfast)
    }

    /// Orders strings by their hash value; used only as a comparator for the timing tests.
    func sortString(a: String, b: String) throws -> Bool {
        return a.hashValue < b.hashValue
    }

    func testExampleStrSlow() {
        // Measures the speed of the old algorithm
        measure {
            do {
                _ = try Array(randomStringData.sorted(by: sortString).prefix(k))
            } catch {
                // Fail the test instead of printing, so a throwing comparator cannot go unnoticed.
                XCTFail("Error sorting with the old algorithm")
            }
        }
    }

    func testExampleStrFast() {
        // Measures the speed of the new algorithm
        measure {
            do {
                _ = try randomStringData.topK(k, by: sortString)
            } catch {
                // Fail the test instead of printing, so a throwing comparator cannot go unnoticed.
                XCTFail("Error sorting with the new algorithm")
            }
        }
    }
}