
Commit 0a9bc42

Authored by flashno, Rik Basu, and ZachNagengast

Multi-channel audio merging (#320)

* Allow developers to merge audio channels. Developers can pick any individual audio channel, sum all channels, or sum specific channels. Set channelMode in WhisperKitConfig and pass the AudioInputConfig.
* Cleanup and formatting.
* Added a helper function to read an audio channel into a buffer. This maintains the integrity of the file. Changed the unit test to use 10 minutes of audio.
* Format

Co-authored-by: Rik Basu <[email protected]>
Co-authored-by: ZachNagengast <[email protected]>

1 parent ca49596 · commit 0a9bc42
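
Usage sketch (editorial, not part of the commit): a minimal example of how the new API is meant to be wired up, per the description above. The file path is a placeholder, and this assumes AudioInputConfig can be constructed and mutated from client code:

    import WhisperKit

    // Pick a channel strategy; .sumChannels(nil) (mix everything) is the default.
    var audioInput = AudioInputConfig()
    audioInput.channelMode = .specificChannel(0) // e.g. left channel only

    let config = WhisperKitConfig(audioInputConfig: audioInput)
    let pipe = try await WhisperKit(config)

    // "interview_stereo.m4a" is a hypothetical multi-channel recording.
    let results = try await pipe.transcribe(audioPath: "interview_stereo.m4a")
    print(results.map(\.text).joined(separator: " "))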

File tree: 7 files changed (+350 −16 lines)

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ jobs:
           name: "watchOS",
           condition: "${{ inputs.macos-runner == 'macos-15' }}",
           clean-destination: "generic/platform=watchOS",
-          test-destination: "platform=watchOS Simulator,OS=11.1,name=Apple Watch Ultra 2 (49mm)",
+          test-destination: "platform=watchOS Simulator,name=Apple Watch Ultra 2 (49mm)",
         }
       - {
           name: "visionOS",

Sources/WhisperKit/Core/Audio/AudioProcessor.swift

Lines changed: 159 additions & 11 deletions
@@ -12,6 +12,7 @@ public typealias DeviceID = AudioDeviceID
 #else
 public typealias DeviceID = String
 #endif
+public typealias ChannelMode = AudioInputConfig.ChannelMode

 public struct AudioDevice: Identifiable, Hashable {
     public let id: DeviceID
@@ -23,19 +24,43 @@ public struct AudioDevice: Identifiable, Hashable {
     }
 }

+/// Configuration for audio input including device selection and channel processing options.
+public struct AudioInputConfig {
+    /// Specifies how to handle audio channels when processing multi-channel audio.
+    public enum ChannelMode: Hashable, Codable {
+        /// Selects a single specific channel by index.
+        /// - Parameter index: The zero-based index of the channel to use.
+        ///   0 selects the first channel, 1 selects the second, etc.
+        case specificChannel(Int)
+
+        /// Mixes all channels together with peak normalization if the parameter is left `nil`.
+        /// - Parameter channels: Array of zero-based channel indices to mix.
+        ///   For example, `[0, 2]` mixes just the first and third channels.
+        ///   The resulting mono audio will maintain the same peak level as the
+        ///   loudest original channel to prevent clipping.
+        case sumChannels([Int]?)
+    }
+
+    /// Specifies how to process channels from multi-channel audio sources.
+    /// Defaults to summing all channels if not explicitly set.
+    public var channelMode: ChannelMode = .sumChannels(nil)
+}
+
 public protocol AudioProcessing {
     /// Loads audio data from a specified file path.
     /// - Parameters:
     ///   - audioFilePath: The file path of the audio file.
+    ///   - channelMode: Channel mode selected for loadAudio
     ///   - startTime: Optional start time in seconds to read from
     ///   - endTime: Optional end time in seconds to read until
     /// - Returns: `AVAudioPCMBuffer` containing the audio data.
-    static func loadAudio(fromPath audioFilePath: String, startTime: Double?, endTime: Double?, maxReadFrameSize: AVAudioFrameCount?) throws -> AVAudioPCMBuffer
+    static func loadAudio(fromPath audioFilePath: String, channelMode: ChannelMode, startTime: Double?, endTime: Double?, maxReadFrameSize: AVAudioFrameCount?) throws -> AVAudioPCMBuffer

     /// Loads and converts audio data from the specified file paths.
     /// - Parameter audioPaths: The file paths of the audio files.
+    /// - Parameter channelMode: Channel mode selected for loadAudio
     /// - Returns: Array of `.success` if the file was loaded and converted correctly, otherwise `.failure`
-    static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>]
+    static func loadAudio(at audioPaths: [String], channelMode: ChannelMode) async -> [Result<[Float], Swift.Error>]

     /// Pad or trim the audio data to the desired length.
     /// - Parameters:
@@ -189,21 +214,22 @@ public class AudioProcessor: NSObject, AudioProcessing {

     public static func loadAudio(
         fromPath audioFilePath: String,
+        channelMode: ChannelMode = .sumChannels(nil),
         startTime: Double? = 0,
         endTime: Double? = nil,
         maxReadFrameSize: AVAudioFrameCount? = nil
     ) throws -> AVAudioPCMBuffer {
         guard FileManager.default.fileExists(atPath: audioFilePath) else {
             throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
         }
-
         let audioFileURL = URL(fileURLWithPath: audioFilePath)
         let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
-        return try loadAudio(fromFile: audioFile, startTime: startTime, endTime: endTime, maxReadFrameSize: maxReadFrameSize)
+        return try loadAudio(fromFile: audioFile, channelMode: channelMode, startTime: startTime, endTime: endTime, maxReadFrameSize: maxReadFrameSize)
     }

     public static func loadAudio(
         fromFile audioFile: AVAudioFile,
+        channelMode: ChannelMode = .sumChannels(nil),
         startTime: Double? = 0,
         endTime: Double? = nil,
         maxReadFrameSize: AVAudioFrameCount? = nil
@@ -241,8 +267,15 @@ public class AudioProcessor: NSObject, AudioProcessing {
             outputBuffer = buffer
         } else {
             // Audio needs resampling to 16khz
-            let maxReadFrameSize = maxReadFrameSize ?? Constants.defaultAudioReadFrameSize
-            outputBuffer = resampleAudio(fromFile: audioFile, toSampleRate: 16000, channelCount: 1, frameCount: frameCount, maxReadFrameSize: maxReadFrameSize)
+            let maxReadSize = maxReadFrameSize ?? Constants.defaultAudioReadFrameSize
+            outputBuffer = resampleAudio(
+                fromFile: audioFile,
+                toSampleRate: 16000,
+                channelCount: 1,
+                channelMode: channelMode,
+                frameCount: frameCount,
+                maxReadFrameSize: maxReadSize
+            )
         }

         if let outputBuffer = outputBuffer {
@@ -259,13 +292,13 @@ public class AudioProcessor: NSObject, AudioProcessing {

     public static func loadAudioAsFloatArray(
         fromPath audioFilePath: String,
+        channelMode: ChannelMode = .sumChannels(nil),
         startTime: Double? = 0,
         endTime: Double? = nil
     ) throws -> [Float] {
         guard FileManager.default.fileExists(atPath: audioFilePath) else {
             throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
         }
-
         let audioFileURL = URL(fileURLWithPath: audioFilePath)
         let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
         let inputSampleRate = audioFile.fileFormat.sampleRate
@@ -287,6 +320,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
         try autoreleasepool {
             let buffer = try loadAudio(
                 fromFile: audioFile,
+                channelMode: channelMode,
                 startTime: currentTime,
                 endTime: chunkEnd
             )
@@ -301,12 +335,12 @@ public class AudioProcessor: NSObject, AudioProcessing {
         return result
     }

-    public static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>] {
+    public static func loadAudio(at audioPaths: [String], channelMode: ChannelMode = .sumChannels(nil)) async -> [Result<[Float], Swift.Error>] {
         await withTaskGroup(of: [(index: Int, result: Result<[Float], Swift.Error>)].self) { taskGroup -> [Result<[Float], Swift.Error>] in
             for (index, audioPath) in audioPaths.enumerated() {
                 taskGroup.addTask {
                     do {
-                        let audio = try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
+                        let audio = try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath, channelMode: channelMode)
                         return [(index: index, result: .success(audio))]
                     } catch {
                         return [(index: index, result: .failure(error))]
@@ -334,6 +368,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
         fromFile audioFile: AVAudioFile,
         toSampleRate sampleRate: Double,
         channelCount: AVAudioChannelCount,
+        channelMode: ChannelMode = .sumChannels(nil),
         frameCount: AVAudioFrameCount? = nil,
         maxReadFrameSize: AVAudioFrameCount = Constants.defaultAudioReadFrameSize
     ) -> AVAudioPCMBuffer? {
@@ -370,7 +405,15 @@ public class AudioProcessor: NSObject, AudioProcessing {

         do {
             try audioFile.read(into: inputBuffer, frameCount: framesToRead)
-            guard let resampledChunk = resampleAudio(fromBuffer: inputBuffer,
+
+            // Convert to mono if needed
+            guard let monoChunk = convertToMono(inputBuffer, mode: channelMode) else {
+                Logging.error("Failed to process audio channels")
+                return nil
+            }
+
+            // Resample mono audio
+            guard let resampledChunk = resampleAudio(fromBuffer: monoChunk,
                                                      toSampleRate: outputFormat.sampleRate,
                                                      channelCount: outputFormat.channelCount)
             else {
@@ -461,6 +504,112 @@ public class AudioProcessor: NSObject, AudioProcessing {
         return convertedBuffer
     }

+    /// Convert multi-channel audio to mono based on the specified mode
+    /// - Parameters:
+    ///   - buffer: The input audio buffer with multiple channels
+    ///   - mode: The channel processing mode
+    /// - Returns: A mono-channel audio buffer
+    public static func convertToMono(_ buffer: AVAudioPCMBuffer, mode: ChannelMode) -> AVAudioPCMBuffer? {
+        let channelCount = Int(buffer.format.channelCount)
+        let frameLength = Int(buffer.frameLength)
+
+        if channelCount <= 1 {
+            // Early return, audio is already mono format
+            return buffer
+        }
+
+        guard let channelData = buffer.floatChannelData else {
+            Logging.error("Buffer did not contain floatChannelData.")
+            return nil
+        }
+
+        // Create a new single-channel buffer
+        guard let monoFormat = AVAudioFormat(
+            commonFormat: .pcmFormatFloat32,
+            sampleRate: buffer.format.sampleRate,
+            channels: 1,
+            interleaved: false
+        ) else {
+            Logging.error("Failed to create AVAudioFormat object.")
+            return nil
+        }
+
+        guard let monoBuffer = AVAudioPCMBuffer(
+            pcmFormat: monoFormat,
+            frameCapacity: buffer.frameCapacity
+        ) else {
+            Logging.error("Failed to create mono buffer.")
+            return nil
+        }
+
+        monoBuffer.frameLength = buffer.frameLength
+
+        // Make sure mono buffer has channel data
+        guard let monoChannelData = monoBuffer.floatChannelData else { return buffer }
+
+        // Clear the buffer to ensure it starts with zeros
+        vDSP_vclr(monoChannelData[0], 1, vDSP_Length(frameLength))
+
+        switch mode {
+        case let .specificChannel(channelIndex):
+            // Copy the specified channel, defaulting to first channel if out of range
+            let safeIndex = (channelIndex >= 0 && channelIndex < channelCount) ? channelIndex : 0
+            memcpy(monoChannelData[0], channelData[safeIndex], frameLength * MemoryLayout<Float>.size)
+
+        case let .sumChannels(channelIndices):
+            // Determine which channels to sum
+            let indicesToSum: [Int]
+
+            if let indices = channelIndices, !indices.isEmpty {
+                // Sum specific channels (filter out invalid indices)
+                indicesToSum = indices.filter { $0 >= 0 && $0 < channelCount }
+
+                // Handle case where all specified indices are invalid
+                if indicesToSum.isEmpty {
+                    memcpy(monoChannelData[0], channelData[0], frameLength * MemoryLayout<Float>.size)
+                    Logging.debug("No valid channel indices provided, defaulting to first channel")
+                    return monoBuffer
+                }
+            } else {
+                // Sum all channels (nil or empty array provided)
+                indicesToSum = Array(0..<channelCount)
+            }
+
+            // First, find the maximum peak across selected input channels
+            var maxOriginalPeak: Float = 0.0
+            for channelIndex in indicesToSum {
+                var channelPeak: Float = 0.0
+                vDSP_maxmgv(channelData[channelIndex], 1, &channelPeak, vDSP_Length(frameLength))
+                maxOriginalPeak = max(maxOriginalPeak, channelPeak)
+            }
+
+            // Sum the specified channels
+            for channelIndex in indicesToSum {
+                vDSP_vadd(
+                    monoChannelData[0], 1,
+                    channelData[channelIndex], 1,
+                    monoChannelData[0], 1,
+                    vDSP_Length(frameLength)
+                )
+            }
+
+            // Find the peak in the mono mix
+            var monoPeak: Float = 0.0
+            vDSP_maxmgv(monoChannelData[0], 1, &monoPeak, vDSP_Length(frameLength))
+
+            // Scale based on peak ratio (avoid division by zero)
+            var scale = maxOriginalPeak / max(monoPeak, 0.0001)
+            vDSP_vsmul(
+                monoChannelData[0], 1,
+                &scale,
+                monoChannelData[0], 1,
+                vDSP_Length(frameLength)
+            )
+        }
+
+        return monoBuffer
+    }
+
     // MARK: - Utility

     /// Detect voice activity in the given buffer of relative energy values.
@@ -584,7 +733,6 @@ public class AudioProcessor: NSObject, AudioProcessing {

         let frameLength = Int(buffer.frameLength)
         let startPointer = channelData[0]
-
         var result: [Float] = []
         result.reserveCapacity(frameLength) // Reserve the capacity to avoid multiple allocations
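
Usage sketch (editorial, not part of the commit): the three ChannelMode options exercised through the static API above; "stereo.wav" is a placeholder path:

    import WhisperKit

    // Each call returns a 16 kHz mono AVAudioPCMBuffer regardless of the input's channel count.
    let leftOnly = try AudioProcessor.loadAudio(fromPath: "stereo.wav", channelMode: .specificChannel(0))
    let mixAll = try AudioProcessor.loadAudio(fromPath: "stereo.wav", channelMode: .sumChannels(nil))
    let mixPair = try AudioProcessor.loadAudio(fromPath: "stereo.wav", channelMode: .sumChannels([0, 1]))

Note the normalization arithmetic in convertToMono: the summed signal is rescaled by maxOriginalPeak / monoPeak. If two identical channels each peak at 0.5, their sum peaks at 1.0, so the mix is scaled by 0.5 back to a 0.5 peak, matching the loudest source channel and avoiding clipping.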

Sources/WhisperKit/Core/Configurations.swift

Lines changed: 4 additions & 1 deletion
@@ -22,9 +22,10 @@ open class WhisperKitConfig {

     /// Model compute options, see `ModelComputeOptions`
     public var computeOptions: ModelComputeOptions?
+    /// Audio input config to define how to process audio input
+    public var audioInputConfig: AudioInputConfig?
     /// Audio processor for the model
     public var audioProcessor: (any AudioProcessing)?
-    /// Audio processor for the model
     public var featureExtractor: (any FeatureExtracting)?
     public var audioEncoder: (any AudioEncoding)?
     public var textDecoder: (any TextDecoding)?
@@ -53,6 +54,7 @@ open class WhisperKitConfig {
         modelFolder: String? = nil,
         tokenizerFolder: URL? = nil,
         computeOptions: ModelComputeOptions? = nil,
+        audioInputConfig: AudioInputConfig? = nil,
         audioProcessor: (any AudioProcessing)? = nil,
         featureExtractor: (any FeatureExtracting)? = nil,
         audioEncoder: (any AudioEncoding)? = nil,
@@ -74,6 +76,7 @@ open class WhisperKitConfig {
         self.modelFolder = modelFolder
         self.tokenizerFolder = tokenizerFolder
         self.computeOptions = computeOptions
+        self.audioInputConfig = audioInputConfig
         self.audioProcessor = audioProcessor
         self.featureExtractor = featureExtractor
         self.audioEncoder = audioEncoder

Sources/WhisperKit/Core/WhisperKit.swift

Lines changed: 4 additions & 3 deletions
@@ -20,6 +20,7 @@ open class WhisperKit {
     }

     public var modelCompute: ModelComputeOptions
+    public var audioInputConfig: AudioInputConfig
     public var tokenizer: WhisperTokenizer?

     /// Protocols
@@ -52,6 +53,7 @@ open class WhisperKit {

     public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws {
         modelCompute = config.computeOptions ?? ModelComputeOptions()
+        audioInputConfig = config.audioInputConfig ?? AudioInputConfig()
         audioProcessor = config.audioProcessor ?? AudioProcessor()
         featureExtractor = config.featureExtractor ?? FeatureExtractor()
         audioEncoder = config.audioEncoder ?? AudioEncoder()
@@ -584,7 +586,7 @@ open class WhisperKit {
         let loadAudioStart = Date()

         // Load and extract audio data from the provided file paths
-        let loadedAudioResult = await AudioProcessor.loadAudio(at: audioPaths)
+        let loadedAudioResult = await AudioProcessor.loadAudio(at: audioPaths, channelMode: audioInputConfig.channelMode)
         let audioArrays = loadedAudioResult.compactMap { try? $0.get() }

         // Calculate the time taken to load and convert audio
@@ -780,8 +782,7 @@ open class WhisperKit {
             Logging.debug("Audio loading and convert time: \(convertTime)")
             logCurrentMemoryUsage("Audio Loading and Convert")
         }
-
-        return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
+        return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath, channelMode: audioInputConfig.channelMode)
     }

     transcriptionStateCallback?(.transcribing)
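
Behavior sketch (editorial, not part of the commit): the fallback in init means channel mixing is on by default for callers that never touch audioInputConfig:

    // Default resolution performed in WhisperKit.init, per the hunk above:
    let resolved = config.audioInputConfig ?? AudioInputConfig()
    assert(resolved.channelMode == .sumChannels(nil)) // all channels summed by default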
Binary file not shown.
