From b9d4088d3c464ab8e88e1b475e15b29fc4b31e89 Mon Sep 17 00:00:00 2001
From: Mohamed Afifi <mohamede1945@gmail.com>
Date: Sun, 26 Nov 2023 14:50:07 -0500
Subject: [PATCH] Fix crash when getOnGoingDownloads is called early

Sometimes `DownloadManager.getOnGoingDownloads` may be called earlier
than `DownloadManager.start`. This result in a crash.

This change ensures that getOnGoingDownloads awaits the initialization.
---
 .../Sources/Features/AsyncInitializer.swift   | 39 +++++++++++++++++++
 .../DownloadBatchDataController.swift         | 32 ++++++++-------
 .../Sources/Downloader/DownloadManager.swift  | 26 +++++++------
 .../Tests/DownloadManagerTests.swift          | 18 +++++++++
 .../BatchDownloaderFake.swift                 | 26 +++++++++++--
 5 files changed, 111 insertions(+), 30 deletions(-)
 create mode 100644 Core/Utilities/Sources/Features/AsyncInitializer.swift

diff --git a/Core/Utilities/Sources/Features/AsyncInitializer.swift b/Core/Utilities/Sources/Features/AsyncInitializer.swift
new file mode 100644
index 00000000..dcf864ef
--- /dev/null
+++ b/Core/Utilities/Sources/Features/AsyncInitializer.swift
@@ -0,0 +1,39 @@
+//
+//  AsyncInitializer.swift
+//
+//
+//  Created by Mohamed Afifi on 2023-11-26.
+//
+
+public struct AsyncInitializer {
+    // MARK: Lifecycle
+
+    public init() {
+        var continuation: AsyncStream<Void>.Continuation!
+        let stream = AsyncStream<Void> { continuation = $0 }
+        self.continuation = continuation
+        self.stream = stream
+    }
+
+    // MARK: Public
+
+    public private(set) var initialized = false
+
+    public mutating func initialize() {
+        initialized = true
+        continuation.finish()
+    }
+
+    public func awaitInitialization() async {
+        if initialized {
+            return
+        }
+        // Wait until the stream finishes
+        for await _ in stream {}
+    }
+
+    // MARK: Private
+
+    private let continuation: AsyncStream<Void>.Continuation
+    private let stream: AsyncStream<Void>
+}
diff --git a/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift b/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift
index 97913e1d..d5419b32 100644
--- a/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift
+++ b/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift
@@ -43,31 +43,23 @@ actor DownloadBatchDataController {
 
     // MARK: Internal
 
-    func bootstrapPersistence() async {
-        do {
-            try await attempt(times: 3) {
-                try await loadBatchesFromPersistence()
-            }
-        } catch {
-            crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.")
-        }
-    }
-
     func start(with session: NetworkSession) async {
+        await bootstrapPersistence()
+
         self.session = session
         let (_, _, downloadTasks) = await session.tasks()
         for batch in batches {
             await batch.associateTasks(downloadTasks)
         }
 
-        loadedInitialRunningTasks = true
+        initialRunningTasks.initialize()
 
         // start pending tasks if needed
         await startPendingTasksIfNeeded()
     }
 
-    func getOnGoingDownloads() -> [DownloadBatchResponse] {
-        precondition(loadedInitialRunningTasks)
+    func getOnGoingDownloads() async -> [DownloadBatchResponse] {
+        await initialRunningTasks.awaitInitialization()
         return Array(batches)
     }
 
@@ -118,7 +110,7 @@ actor DownloadBatchDataController {
 
     private var batches: Set<DownloadBatchResponse> = []
 
-    private var loadedInitialRunningTasks = false
+    private var initialRunningTasks = AsyncInitializer()
 
     private var runningTasks: Int {
         get async {
@@ -139,6 +131,16 @@ actor DownloadBatchDataController {
         }
     }
 
+    private func bootstrapPersistence() async {
+        do {
+            try await attempt(times: 3) {
+                try await loadBatchesFromPersistence()
+            }
+        } catch {
+            crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.")
+        }
+    }
+
     private func loadBatchesFromPersistence() async throws {
         let batches = try await persistence.retrieveAll()
         logger.info("Loading \(batches.count) from persistence")
@@ -172,7 +174,7 @@ actor DownloadBatchDataController {
     }
 
     private func startPendingTasksIfNeeded() async {
-        if !loadedInitialRunningTasks {
+        if !initialRunningTasks.initialized {
             return
         }
 
diff --git a/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift b/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift
index 37e364c8..2a0f28ea 100644
--- a/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift
+++ b/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift
@@ -67,17 +67,7 @@ public final class DownloadManager: Sendable {
 
     public func start() async {
         logger.info("Starting download manager")
-        let operationQueue = OperationQueue()
-        operationQueue.name = "com.quran.downloads"
-        operationQueue.maxConcurrentOperationCount = 1
-
-        let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch")
-        operationQueue.underlyingQueue = dispatchQueue
-
-        await dataController.bootstrapPersistence()
-
-        let session = sessionFactory(handler, operationQueue)
-        self.session = session
+        let session = createSession()
         await dataController.start(with: session)
         logger.info("Download manager start completed")
     }
@@ -101,4 +91,18 @@ public final class DownloadManager: Sendable {
     private var session: NetworkSession?
     private let handler: DownloadSessionDelegate
     private let dataController: DownloadBatchDataController
+
+    private func createSession() -> NetworkSession {
+        let operationQueue = OperationQueue()
+        operationQueue.name = "com.quran.downloads"
+        operationQueue.maxConcurrentOperationCount = 1
+
+        let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch")
+        operationQueue.underlyingQueue = dispatchQueue
+
+        let session = sessionFactory(handler, operationQueue)
+        self.session = session
+
+        return session
+    }
 }
diff --git a/Data/BatchDownloader/Tests/DownloadManagerTests.swift b/Data/BatchDownloader/Tests/DownloadManagerTests.swift
index 9d04b16f..69bab2a7 100644
--- a/Data/BatchDownloader/Tests/DownloadManagerTests.swift
+++ b/Data/BatchDownloader/Tests/DownloadManagerTests.swift
@@ -67,6 +67,24 @@ final class DownloadManagerTests: XCTestCase {
         XCTAssertEqual(calls.calls, 1)
     }
 
+    func test_onGoingDownloads_whileStartNotFinished() async throws {
+        // Load a single batch
+        let batch = DownloadBatchRequest(requests: [request1])
+        _ = try await downloader.download(batch)
+
+        // Deallocate downloader & create new one
+        downloader = nil
+        downloader = await BatchDownloaderFake.makeDownloaderDontWaitForSession()
+
+        // Test calling getOnGoingDownloads and start at the same time.
+        async let startTask: () = await downloader.start()
+        async let downloadsTask = await downloader.getOnGoingDownloads()
+        let (downloads, _) = await (downloadsTask, startTask)
+
+        // Verify
+        XCTAssertEqual(downloads.count, 1)
+    }
+
     func testLoadingOnGoingDownload() async throws {
         let emptyDownloads = await downloader.getOnGoingDownloads()
         XCTAssertEqual(emptyDownloads.count, 0)
diff --git a/Data/BatchDownloaderFake/BatchDownloaderFake.swift b/Data/BatchDownloaderFake/BatchDownloaderFake.swift
index c65e8ee5..95d8290f 100644
--- a/Data/BatchDownloaderFake/BatchDownloaderFake.swift
+++ b/Data/BatchDownloaderFake/BatchDownloaderFake.swift
@@ -20,10 +20,7 @@ public enum BatchDownloaderFake {
     public static let downloadsURL = RelativeFilePath(downloads, isDirectory: true)
 
     public static func makeDownloader(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> (DownloadManager, NetworkSessionFake) {
-        try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true)
-        let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false)
-
-        let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url)
+        let persistence = makeDownloadsPersistence()
         actor SessionActor {
             var session: NetworkSessionFake!
             let channel = AsyncChannel<Void>()
@@ -50,6 +47,20 @@ public enum BatchDownloaderFake {
         return (downloader, await sessionActor.session)
     }
 
+    public static func makeDownloaderDontWaitForSession(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> DownloadManager {
+        let persistence = makeDownloadsPersistence()
+        let downloader = DownloadManager(
+            maxSimultaneousDownloads: maxSimultaneousDownloads,
+            sessionFactory: { delegate, queue in
+                let session = NetworkSessionFake(queue: queue, delegate: delegate, downloads: downloads)
+                return session
+            },
+            persistence: persistence,
+            fileManager: fileManager
+        )
+        return downloader
+    }
+
     public static func tearDown() {
         try? FileManager.default.removeItem(at: Self.downloadsURL)
     }
@@ -73,4 +84,11 @@ public enum BatchDownloaderFake {
     // MARK: Private
 
     private static let downloads = "downloads"
+
+    private static func makeDownloadsPersistence() -> GRDBDownloadsPersistence {
+        try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true)
+        let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false)
+        let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url)
+        return persistence
+    }
 }