From 7bf84238c1d98d159063385a7be34620b8c640d2 Mon Sep 17 00:00:00 2001 From: buhe Date: Tue, 13 Feb 2024 12:14:54 +0800 Subject: [PATCH 1/3] "Refactor SimilarityIndex initialization and setupDimension method for improved efficiency and clarity." --- .../Core/Index/SimilarityIndex.swift | 8 +++---- .../BenchmarkTests.swift | 7 +++--- .../SimilaritySearchKitTests.swift | 24 +++++++++---------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift index cb1559e..467d618 100644 --- a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift +++ b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift @@ -101,18 +101,16 @@ public class SimilarityIndex { // MARK: - Initializers - public init(name: String? = nil, model: (any EmbeddingsProtocol)? = nil, metric: (any DistanceMetricProtocol)? = nil, vectorStore: (any VectorStoreProtocol)? = nil) async { + public init(name: String? = nil, model: (any EmbeddingsProtocol)? = nil, metric: (any DistanceMetricProtocol)? = nil, vectorStore: (any VectorStoreProtocol)? = nil) { // Setup index with defaults self.indexName = name ?? "SimilaritySearchKitIndex" self.indexModel = model ?? NativeEmbeddings() self.indexMetric = metric ?? CosineSimilarity() self.vectorStore = vectorStore ?? JsonStore() - - // Run the model once to discover dimention size - await setupDimension() } - private func setupDimension() async { + // Run the model once to discover dimention size + public func setupDimension() async { if let testVector = await indexModel.encode(sentence: "Test sentence") { dimension = testVector.count } else { diff --git a/Tests/SimilaritySearchKitTests/BenchmarkTests.swift b/Tests/SimilaritySearchKitTests/BenchmarkTests.swift index e181a25..f1004c5 100644 --- a/Tests/SimilaritySearchKitTests/BenchmarkTests.swift +++ b/Tests/SimilaritySearchKitTests/BenchmarkTests.swift @@ -41,7 +41,8 @@ class BenchmarkTests: XCTestCase { let expectation = XCTestExpectation(description: "Encoding passage texts") Task { - let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings()) + let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings()) + await similarityIndex.setupDimension() await similarityIndex.addItems( ids: [UUID().uuidString], texts: [searchPassage.text], @@ -125,8 +126,8 @@ class BenchmarkTests: XCTestCase { Task { print("\nGenerating similarity index for \(testAmount) passages") - let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings()) - + let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings()) + await similarityIndex.setupDimension() var startTime = CFAbsoluteTimeGetCurrent() await similarityIndex.addItems( ids: passageIds, diff --git a/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift b/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift index c908040..eb271ed 100644 --- a/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift +++ b/Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift @@ -21,8 +21,8 @@ class SimilaritySearchKitTests: XCTestCase { } func testSavingJsonIndex() async { - let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore()) - + let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore()) + await similarityIndex.setupDimension() await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"], embedding: [0.1, 0.2, 0.3]) let successPath = try! similarityIndex.saveIndex(name: "TestIndexForSaving") @@ -31,24 +31,24 @@ class SimilaritySearchKitTests: XCTestCase { } func testLoadingJsonIndex() async { - let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore()) - + let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore()) + await similarityIndex.setupDimension() await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"]) let successPath = try! similarityIndex.saveIndex(name: "TestIndexForLoading") XCTAssertNotNil(successPath) - let similarityIndex2 = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore()) - + let similarityIndex2 = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore()) + await similarityIndex2.setupDimension() let loadedItems = try! similarityIndex2.loadIndex(name: "TestIndexForLoading") XCTAssertNotNil(loadedItems) } func testSavingBinaryIndex() async { - let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore()) - + let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore()) + await similarityIndex.setupDimension() await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"], embedding: [0.1, 0.2, 0.3]) let successPath = try! similarityIndex.saveIndex(name: "TestIndexForSaving") @@ -57,16 +57,16 @@ class SimilaritySearchKitTests: XCTestCase { } func testLoadingBinaryIndex() async { - let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore()) - + let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore()) + await similarityIndex.setupDimension() await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"]) let successPath = try! similarityIndex.saveIndex(name: "TestIndexForLoading") XCTAssertNotNil(successPath) - let similarityIndex2 = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore()) - + let similarityIndex2 = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore()) + await similarityIndex.setupDimension() let loadedItems = try! similarityIndex2.loadIndex(name: "TestIndexForLoading") XCTAssertNotNil(loadedItems) From b00ebd85bd72c2fcc59eb97dc534ba23a04ca031 Mon Sep 17 00:00:00 2001 From: buhe Date: Wed, 14 Feb 2024 11:34:15 +0800 Subject: [PATCH 2/3] "Refactor loadIndex method to remove unnecessary closure and improve code readability." --- .../SimilaritySearchKit/Core/Index/SimilarityIndex.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift index 467d618..0b09b4b 100644 --- a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift +++ b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift @@ -325,9 +325,9 @@ extension SimilarityIndex { public func loadIndex(fromDirectory path: URL? = nil, name: String? = nil) throws -> [IndexItem]? { if let indexPath = try getIndexPath(fromDirectory: path, name: name) { let loadedIndexItems = try vectorStore.loadIndex(from: indexPath) - addItems(loadedIndexItems) {[self] in - print("Loaded \(indexItems.count) index items from \(indexPath.absoluteString)") - } +// addItems(loadedIndexItems) {[self] in + print("Loaded \(indexItems.count) index items from \(indexPath.absoluteString)") +// } return loadedIndexItems } From 39bbd0b5c2662728792ee9be132212e5fd4a49e2 Mon Sep 17 00:00:00 2001 From: buhe Date: Wed, 14 Feb 2024 11:47:58 +0800 Subject: [PATCH 3/3] "Refactor loadIndex method to assign loaded index items directly to indexItems variable." --- .../SimilaritySearchKit/Core/Index/SimilarityIndex.swift | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift index 0b09b4b..7a590d2 100644 --- a/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift +++ b/Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift @@ -324,11 +324,9 @@ extension SimilarityIndex { public func loadIndex(fromDirectory path: URL? = nil, name: String? = nil) throws -> [IndexItem]? { if let indexPath = try getIndexPath(fromDirectory: path, name: name) { - let loadedIndexItems = try vectorStore.loadIndex(from: indexPath) -// addItems(loadedIndexItems) {[self] in + indexItems = try vectorStore.loadIndex(from: indexPath) print("Loaded \(indexItems.count) index items from \(indexPath.absoluteString)") -// } - return loadedIndexItems + return indexItems } return nil