Skip to content

Commit

Permalink
Use multiple threads for prefetching.
Browse files Browse the repository at this point in the history
Signed-off-by: Pascal Spörri <[email protected]>
  • Loading branch information
pspoerri committed Sep 7, 2023
1 parent 18191d1 commit 7f6db51
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Changing these values might have an impact on performance.

- `spark.shuffle.s3.bufferSize`: Default buffer size when writing (default: `8388608`)
- `spark.shuffle.s3.maxBufferSizeTask`: Maximum size of the buffered output streams per task (default: `134217728`)
- `spark.shuffle.s3.prefetchConcurrencyTask`: Number of threads each task uses for prefetching shuffle blocks (default: `2`)
- `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
- `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
- `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class S3ShuffleDispatcher extends Logging {
// Optional
val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
val maxBufferSizeTask: Int = conf.getInt("spark.shuffle.s3.maxBufferSizeTask", defaultValue = 128 * 1024 * 1024)
val prefetchConcurrencyTask: Int = conf.getInt("spark.shuffle.s3.prefetchConcurrencyTask", defaultValue = 2)
val cachePartitionLengths: Boolean = conf.getBoolean("spark.shuffle.s3.cachePartitionLengths", defaultValue = true)
val cacheChecksums: Boolean = conf.getBoolean("spark.shuffle.s3.cacheChecksums", defaultValue = true)
val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
Expand Down Expand Up @@ -60,6 +61,7 @@ class S3ShuffleDispatcher extends Logging {
// Optional
logInfo(s"- spark.shuffle.s3.bufferSize=${bufferSize}")
logInfo(s"- spark.shuffle.s3.maxBufferSizeTask=${maxBufferSizeTask}")
logInfo(s"- spark.shuffle.s3.prefetchConcurrencyTask=${prefetchConcurrencyTask}")
logInfo(s"- spark.shuffle.s3.cachePartitionLengths=${cachePartitionLengths}")
logInfo(s"- spark.shuffle.s3.cacheChecksums=${cacheChecksums}")
logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
Expand Down Expand Up @@ -112,7 +114,7 @@ class S3ShuffleDispatcher extends Logging {
def openBlock(blockId: BlockId): FSDataInputStream = {
val status = getFileStatusCached(blockId)
val builder = fs.openFile(status.getPath).withFileStatus(status)
val stream = builder.build().get()
val stream = builder.build().get()
if (canSetReadahead) {
stream.setReadahead(0)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,46 @@
package org.apache.spark.storage

import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.helper.S3ShuffleDispatcher

import java.io.{BufferedInputStream, InputStream}
import java.util

class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long) extends Iterator[(BlockId, InputStream)] with Logging {

private val concurrencyTask = S3ShuffleDispatcher.get.prefetchConcurrencyTask
private val startTime = System.nanoTime()

@volatile private var memoryUsage: Long = 0
@volatile private var hasItem: Boolean = iter.hasNext
private var timeWaiting: Long = 0
private var timePrefetching: Long = 0
private var timeNext: Long = 0
private var numStreams: Long = 0
private var bytesRead: Long = 0

private var nextElement: (BlockId, S3ShuffleBlockStream) = null
private var activeTasks: Long = 0

private val completed = new util.LinkedList[(InputStream, BlockId, Long)]()

private def prefetchThread(): Unit = {
while (iter.hasNext || nextElement != null) {
if (nextElement == null) {
val now = System.nanoTime()
nextElement = iter.next()
timeNext = System.nanoTime() - now
var nextElement: (BlockId, S3ShuffleBlockStream) = null
while (true) {
synchronized {
if (!iter.hasNext && nextElement == null) {
hasItem = false
return
}
if (nextElement == null) {
nextElement = iter.next()
activeTasks += 1
hasItem = iter.hasNext
}
}
val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt

var fetchNext = false
val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
synchronized {
if (memoryUsage + math.min(bsize, maxBufferSize) > maxBufferSize) {
if (memoryUsage + bsize > maxBufferSize) {
try {
wait()
}
Expand All @@ -43,6 +54,7 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
}
} else {
fetchNext = true
memoryUsage += bsize
}
}

Expand All @@ -59,50 +71,49 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
timePrefetching += System.nanoTime() - now
bytesRead += bsize
synchronized {
memoryUsage += bsize
completed.push((stream, block, bsize))
hasItem = iter.hasNext
notify()
activeTasks -= 1
notifyAll()
}
}
}
}

private val self = this
private val thread = new Thread {
private val threads = Array.fill[Thread](concurrencyTask)(new Thread {
override def run(): Unit = {
self.prefetchThread()
}
}
thread.start()
})
threads.foreach(_.start())

private def printStatistics(): Unit = synchronized {
val totalRuntime = System.nanoTime() - startTime
try {
val tR = totalRuntime / 1000000
val wPer = 100 * timeWaiting / totalRuntime
val tW = timeWaiting / 1000000
val tP = timePrefetching / 1000000
val tN = timeNext / 1000000
val bR = bytesRead
val r = numStreams
// Average time per prefetch
val atP = tP / r
// Average time waiting
val atW = tW / r
// Average time next
val atN = tN / r
// Average read bandwidth
val bW = bR.toDouble / (tP.toDouble / 1000) / (1024 * 1024)
// Block size
val bs = bR / r
logInfo(s"Statistics: ${bR} bytes, ${tW} ms waiting (${atW} avg), " +
s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
s"${tN} ms for next (${atN} avg)")
s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " +
s"Total: ${tR} ms - ${wPer}% waiting")
} catch {
case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.")
}
}

override def hasNext: Boolean = synchronized {
val result = hasItem || (completed.size() > 0)
val result = hasItem || activeTasks > 0 || (completed.size() > 0)
if (!result) {
printStatistics()
}
Expand Down

0 comments on commit 7f6db51

Please sign in to comment.