Skip to content

Commit 4d35452

Browse files
authored
Fix crash when getOnGoingDownloads is called early (#595)
Sometimes `DownloadManager.getOnGoingDownloads` may be called earlier than `DownloadManager.start`. This result in a crash. This change ensures that getOnGoingDownloads awaits the initialization.
1 parent 4f1a562 commit 4d35452

File tree

5 files changed

+111
-30
lines changed

5 files changed

+111
-30
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//
2+
// AsyncInitializer.swift
3+
//
4+
//
5+
// Created by Mohamed Afifi on 2023-11-26.
6+
//
7+
8+
public struct AsyncInitializer {
9+
// MARK: Lifecycle
10+
11+
public init() {
12+
var continuation: AsyncStream<Void>.Continuation!
13+
let stream = AsyncStream<Void> { continuation = $0 }
14+
self.continuation = continuation
15+
self.stream = stream
16+
}
17+
18+
// MARK: Public
19+
20+
public private(set) var initialized = false
21+
22+
public mutating func initialize() {
23+
initialized = true
24+
continuation.finish()
25+
}
26+
27+
public func awaitInitialization() async {
28+
if initialized {
29+
return
30+
}
31+
// Wait until the stream finishes
32+
for await _ in stream {}
33+
}
34+
35+
// MARK: Private
36+
37+
private let continuation: AsyncStream<Void>.Continuation
38+
private let stream: AsyncStream<Void>
39+
}

Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,31 +43,23 @@ actor DownloadBatchDataController {
4343

4444
// MARK: Internal
4545

46-
func bootstrapPersistence() async {
47-
do {
48-
try await attempt(times: 3) {
49-
try await loadBatchesFromPersistence()
50-
}
51-
} catch {
52-
crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.")
53-
}
54-
}
55-
5646
func start(with session: NetworkSession) async {
47+
await bootstrapPersistence()
48+
5749
self.session = session
5850
let (_, _, downloadTasks) = await session.tasks()
5951
for batch in batches {
6052
await batch.associateTasks(downloadTasks)
6153
}
6254

63-
loadedInitialRunningTasks = true
55+
initialRunningTasks.initialize()
6456

6557
// start pending tasks if needed
6658
await startPendingTasksIfNeeded()
6759
}
6860

69-
func getOnGoingDownloads() -> [DownloadBatchResponse] {
70-
precondition(loadedInitialRunningTasks)
61+
func getOnGoingDownloads() async -> [DownloadBatchResponse] {
62+
await initialRunningTasks.awaitInitialization()
7163
return Array(batches)
7264
}
7365

@@ -118,7 +110,7 @@ actor DownloadBatchDataController {
118110

119111
private var batches: Set<DownloadBatchResponse> = []
120112

121-
private var loadedInitialRunningTasks = false
113+
private var initialRunningTasks = AsyncInitializer()
122114

123115
private var runningTasks: Int {
124116
get async {
@@ -139,6 +131,16 @@ actor DownloadBatchDataController {
139131
}
140132
}
141133

134+
private func bootstrapPersistence() async {
135+
do {
136+
try await attempt(times: 3) {
137+
try await loadBatchesFromPersistence()
138+
}
139+
} catch {
140+
crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.")
141+
}
142+
}
143+
142144
private func loadBatchesFromPersistence() async throws {
143145
let batches = try await persistence.retrieveAll()
144146
logger.info("Loading \(batches.count) from persistence")
@@ -172,7 +174,7 @@ actor DownloadBatchDataController {
172174
}
173175

174176
private func startPendingTasksIfNeeded() async {
175-
if !loadedInitialRunningTasks {
177+
if !initialRunningTasks.initialized {
176178
return
177179
}
178180

Data/BatchDownloader/Sources/Downloader/DownloadManager.swift

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,7 @@ public final class DownloadManager: Sendable {
6767

6868
public func start() async {
6969
logger.info("Starting download manager")
70-
let operationQueue = OperationQueue()
71-
operationQueue.name = "com.quran.downloads"
72-
operationQueue.maxConcurrentOperationCount = 1
73-
74-
let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch")
75-
operationQueue.underlyingQueue = dispatchQueue
76-
77-
await dataController.bootstrapPersistence()
78-
79-
let session = sessionFactory(handler, operationQueue)
80-
self.session = session
70+
let session = createSession()
8171
await dataController.start(with: session)
8272
logger.info("Download manager start completed")
8373
}
@@ -101,4 +91,18 @@ public final class DownloadManager: Sendable {
10191
private var session: NetworkSession?
10292
private let handler: DownloadSessionDelegate
10393
private let dataController: DownloadBatchDataController
94+
95+
private func createSession() -> NetworkSession {
96+
let operationQueue = OperationQueue()
97+
operationQueue.name = "com.quran.downloads"
98+
operationQueue.maxConcurrentOperationCount = 1
99+
100+
let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch")
101+
operationQueue.underlyingQueue = dispatchQueue
102+
103+
let session = sessionFactory(handler, operationQueue)
104+
self.session = session
105+
106+
return session
107+
}
104108
}

Data/BatchDownloader/Tests/DownloadManagerTests.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,24 @@ final class DownloadManagerTests: XCTestCase {
6767
XCTAssertEqual(calls.calls, 1)
6868
}
6969

70+
func test_onGoingDownloads_whileStartNotFinished() async throws {
71+
// Load a single batch
72+
let batch = DownloadBatchRequest(requests: [request1])
73+
_ = try await downloader.download(batch)
74+
75+
// Deallocate downloader & create new one
76+
downloader = nil
77+
downloader = await BatchDownloaderFake.makeDownloaderDontWaitForSession()
78+
79+
// Test calling getOnGoingDownloads and start at the same time.
80+
async let startTask: () = await downloader.start()
81+
async let downloadsTask = await downloader.getOnGoingDownloads()
82+
let (downloads, _) = await (downloadsTask, startTask)
83+
84+
// Verify
85+
XCTAssertEqual(downloads.count, 1)
86+
}
87+
7088
func testLoadingOnGoingDownload() async throws {
7189
let emptyDownloads = await downloader.getOnGoingDownloads()
7290
XCTAssertEqual(emptyDownloads.count, 0)

Data/BatchDownloaderFake/BatchDownloaderFake.swift

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@ public enum BatchDownloaderFake {
2020
public static let downloadsURL = RelativeFilePath(downloads, isDirectory: true)
2121

2222
public static func makeDownloader(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> (DownloadManager, NetworkSessionFake) {
23-
try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true)
24-
let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false)
25-
26-
let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url)
23+
let persistence = makeDownloadsPersistence()
2724
actor SessionActor {
2825
var session: NetworkSessionFake!
2926
let channel = AsyncChannel<Void>()
@@ -50,6 +47,20 @@ public enum BatchDownloaderFake {
5047
return (downloader, await sessionActor.session)
5148
}
5249

50+
public static func makeDownloaderDontWaitForSession(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> DownloadManager {
51+
let persistence = makeDownloadsPersistence()
52+
let downloader = DownloadManager(
53+
maxSimultaneousDownloads: maxSimultaneousDownloads,
54+
sessionFactory: { delegate, queue in
55+
let session = NetworkSessionFake(queue: queue, delegate: delegate, downloads: downloads)
56+
return session
57+
},
58+
persistence: persistence,
59+
fileManager: fileManager
60+
)
61+
return downloader
62+
}
63+
5364
public static func tearDown() {
5465
try? FileManager.default.removeItem(at: Self.downloadsURL)
5566
}
@@ -73,4 +84,11 @@ public enum BatchDownloaderFake {
7384
// MARK: Private
7485

7586
private static let downloads = "downloads"
87+
88+
private static func makeDownloadsPersistence() -> GRDBDownloadsPersistence {
89+
try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true)
90+
let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false)
91+
let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url)
92+
return persistence
93+
}
7694
}

0 commit comments

Comments
 (0)