Skip to content

Commit

Permalink
Fix crash when getOnGoingDownloads is called early
Browse files Browse the repository at this point in the history
Sometimes `DownloadManager.getOnGoingDownloads` may be called earlier
than `DownloadManager.start`. This result in a crash.

This change ensures that getOnGoingDownloads awaits the initialization.
  • Loading branch information
mohamede1945 committed Nov 26, 2023
1 parent 4f1a562 commit 69274ab
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,40 @@ struct SingleTaskResponse {
let response: DownloadBatchResponse
}

// Non thread safe since it's going to be used by an actor
private struct Initializer {
// MARK: Lifecycle

init() {
var continuation: AsyncStream<Void>.Continuation!
let stream = AsyncStream<Void> { continuation = $0 }
self.continuation = continuation
self.stream = stream
}

// MARK: Internal

private(set) var initialized = false

mutating func initialize() {
initialized = true
continuation.finish()
}

func awaitInitialization() async {
if initialized {
return
}
// Wait until the stream finishes
for await _ in stream {}
}

// MARK: Private

private let continuation: AsyncStream<Void>.Continuation
private let stream: AsyncStream<Void>
}

actor DownloadBatchDataController {
// MARK: Lifecycle

Expand All @@ -43,31 +77,23 @@ actor DownloadBatchDataController {

// MARK: Internal

func bootstrapPersistence() async {
do {
try await attempt(times: 3) {
try await loadBatchesFromPersistence()
}
} catch {
crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.")
}
}

func start(with session: NetworkSession) async {
await bootstrapPersistence()

self.session = session
let (_, _, downloadTasks) = await session.tasks()
for batch in batches {
await batch.associateTasks(downloadTasks)
}

loadedInitialRunningTasks = true
initialRunningTasks.initialize()

// start pending tasks if needed
await startPendingTasksIfNeeded()
}

func getOnGoingDownloads() -> [DownloadBatchResponse] {
precondition(loadedInitialRunningTasks)
func getOnGoingDownloads() async -> [DownloadBatchResponse] {
await initialRunningTasks.awaitInitialization()
return Array(batches)
}

Expand Down Expand Up @@ -118,7 +144,7 @@ actor DownloadBatchDataController {

private var batches: Set<DownloadBatchResponse> = []

private var loadedInitialRunningTasks = false
private var initialRunningTasks = Initializer()

private var runningTasks: Int {
get async {
Expand All @@ -139,6 +165,16 @@ actor DownloadBatchDataController {
}
}

private func bootstrapPersistence() async {
do {
try await attempt(times: 3) {
try await loadBatchesFromPersistence()
}
} catch {
crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.")
}
}

private func loadBatchesFromPersistence() async throws {
let batches = try await persistence.retrieveAll()
logger.info("Loading \(batches.count) from persistence")
Expand Down Expand Up @@ -172,7 +208,7 @@ actor DownloadBatchDataController {
}

private func startPendingTasksIfNeeded() async {
if !loadedInitialRunningTasks {
if !initialRunningTasks.initialized {
return
}

Expand Down
26 changes: 15 additions & 11 deletions Data/BatchDownloader/Sources/Downloader/DownloadManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,7 @@ public final class DownloadManager: Sendable {

public func start() async {
logger.info("Starting download manager")
let operationQueue = OperationQueue()
operationQueue.name = "com.quran.downloads"
operationQueue.maxConcurrentOperationCount = 1

let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch")
operationQueue.underlyingQueue = dispatchQueue

await dataController.bootstrapPersistence()

let session = sessionFactory(handler, operationQueue)
self.session = session
let session = createSession()
await dataController.start(with: session)
logger.info("Download manager start completed")
}
Expand All @@ -101,4 +91,18 @@ public final class DownloadManager: Sendable {
private var session: NetworkSession?
private let handler: DownloadSessionDelegate
private let dataController: DownloadBatchDataController

private func createSession() -> NetworkSession {
let operationQueue = OperationQueue()
operationQueue.name = "com.quran.downloads"
operationQueue.maxConcurrentOperationCount = 1

let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch")
operationQueue.underlyingQueue = dispatchQueue

let session = sessionFactory(handler, operationQueue)
self.session = session

return session
}
}
18 changes: 18 additions & 0 deletions Data/BatchDownloader/Tests/DownloadManagerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,24 @@ final class DownloadManagerTests: XCTestCase {
XCTAssertEqual(calls.calls, 1)
}

func test_onGoingDownloads_whileStartNotFinished() async throws {
// Load a single batch
let batch = DownloadBatchRequest(requests: [request1])
_ = try await downloader.download(batch)

// Deallocate downloader & create new one
downloader = nil
downloader = await BatchDownloaderFake.makeDownloaderDontWaitForSession()

// Test calling getOnGoingDownloads and start at the same time.
async let startTask: () = await downloader.start()
async let downloadsTask = await downloader.getOnGoingDownloads()
let (downloads, _) = await (downloadsTask, startTask)

// Verify
XCTAssertEqual(downloads.count, 1)
}

func testLoadingOnGoingDownload() async throws {
let emptyDownloads = await downloader.getOnGoingDownloads()
XCTAssertEqual(emptyDownloads.count, 0)
Expand Down
26 changes: 22 additions & 4 deletions Data/BatchDownloaderFake/BatchDownloaderFake.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ public enum BatchDownloaderFake {
public static let downloadsURL = RelativeFilePath(downloads, isDirectory: true)

public static func makeDownloader(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> (DownloadManager, NetworkSessionFake) {
try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true)
let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false)

let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url)
let persistence = makeDownloadsPersistence()
actor SessionActor {
var session: NetworkSessionFake!
let channel = AsyncChannel<Void>()
Expand All @@ -50,6 +47,20 @@ public enum BatchDownloaderFake {
return (downloader, await sessionActor.session)
}

public static func makeDownloaderDontWaitForSession(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> DownloadManager {
let persistence = makeDownloadsPersistence()
let downloader = DownloadManager(
maxSimultaneousDownloads: maxSimultaneousDownloads,
sessionFactory: { delegate, queue in
let session = NetworkSessionFake(queue: queue, delegate: delegate, downloads: downloads)
return session
},
persistence: persistence,
fileManager: fileManager
)
return downloader
}

public static func tearDown() {
try? FileManager.default.removeItem(at: Self.downloadsURL)
}
Expand All @@ -73,4 +84,11 @@ public enum BatchDownloaderFake {
// MARK: Private

private static let downloads = "downloads"

private static func makeDownloadsPersistence() -> GRDBDownloadsPersistence {
try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true)
let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false)
let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url)
return persistence
}
}

0 comments on commit 69274ab

Please sign in to comment.