From c3948c86742cbdb7476f0e3ad98e5836b04626a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com>
Date: Thu, 28 Nov 2024 01:18:12 +0900
Subject: [PATCH] feat(tools): load_audio supports mean of stereo

---
Review notes (this section is ignored by `git am`):
- Decoded chunks are collected in a list and joined once with
  np.concatenate; np.resize re-flows and repeats the flattened data, so
  growing a (channels, n) buffer with it corrupted the audio.
- Frames are always converted to planar float32 ("fltp"), also when
  sr is None, so every chunk is (channels, samples); rate=None keeps the
  source sample rate.
- The channel count comes from the stream layout instead of assuming
  mono/stereo, and a mono result is always a 1-D array.
- The resampler is flushed at end of stream and the container is closed
  by a `with` block.
- A container without an audio stream raises ValueError instead of
  leaking StopIteration.
- The hashes on the `From` and `index` lines still refer to the original
  revision of this patch.

 tools/audio/av.py | 97 +++++++++++++++++++++++++++++++++-------------------
 1 file changed, 62 insertions(+), 35 deletions(-)

diff --git a/tools/audio/av.py b/tools/audio/av.py
index a38b3379a..2c14baf9e 100644
--- a/tools/audio/av.py
+++ b/tools/audio/av.py
@@ -1,8 +1,9 @@
 from io import BufferedWriter, BytesIO
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Tuple, Optional, Union, List
 
 import av
+from av.audio.frame import AudioFrame
 from av.audio.resampler import AudioResampler
 
 import numpy as np
@@ -39,41 +40,67 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
     inp.close()
 
 
-def load_audio(file: str, sr: int) -> np.ndarray:
+def load_audio(
+    file: Union[str, BytesIO, Path],
+    sr: Optional[int] = None,
+    format: Optional[str] = None,
+    mono: bool = True,
+) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
     """
     https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39
+
+    Decode the first audio stream of ``file`` into float32 samples.
+
+    Args:
+        file: Path (``str``/``Path``) or in-memory ``BytesIO`` to decode.
+        sr: Target sample rate in Hz; ``None`` keeps the source rate.
+        format: Container format hint forwarded to ``av.open``.
+        mono: Average all channels into a 1-D array when true; otherwise
+            keep the ``(channels, samples)`` layout.
+
+    Returns:
+        The samples if ``sr`` was given, else ``(samples, sample_rate)``.
+
+    Raises:
+        FileNotFoundError: ``file`` is a path that does not exist.
+        ValueError: The container holds no audio stream.
     """
-
-    if not Path(file).exists():
+    if isinstance(file, (str, Path)) and not Path(file).exists():
         raise FileNotFoundError(f"File not found: {file}")
-
-    try:
-        container = av.open(file)
-        resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
-
-        # Estimated maximum total number of samples to pre-allocate the array
-        # AV stores length in microseconds by default
-        estimated_total_samples = int(container.duration * sr // 1_000_000)
-        decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
-
-        offset = 0
-        for frame in container.decode(audio=0):
-            frame.pts = None # Clear presentation timestamp to avoid resampling issues
-            resampled_frames = resampler.resample(frame)
-            for resampled_frame in resampled_frames:
-                frame_data = resampled_frame.to_ndarray()[0]
-                end_index = offset + len(frame_data)
-
-                # Check if decoded_audio has enough space, and resize if necessary
-                if end_index > decoded_audio.shape[0]:
-                    decoded_audio = np.resize(decoded_audio, end_index + 1)
-
-                decoded_audio[offset:end_index] = frame_data
-                offset += len(frame_data)
-
-        # Truncate the array to the actual size
-        decoded_audio = decoded_audio[:offset]
-    except Exception as e:
-        raise RuntimeError(f"Failed to load audio: {e}")
-
-    return decoded_audio
+
+    chunks: List[np.ndarray] = []
+    rate = 0
+
+    with av.open(file, format=format) as container:
+        if not container.streams.audio:
+            raise ValueError(f"No audio stream found: {file}")
+        audio_stream = container.streams.audio[0]
+        channels = audio_stream.layout.nb_channels
+        # Always convert to planar float32 so that every chunk is
+        # (channels, samples) whatever the decoder emits; rate=None
+        # keeps the source sample rate.
+        resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr)
+
+        def collect(frame: Optional[AudioFrame]) -> None:
+            nonlocal rate
+            for resampled in resampler.resample(frame):
+                chunks.append(resampled.to_ndarray())
+                rate = resampled.rate
+
+        for frame in container.decode(audio_stream):
+            frame.pts = None  # clear the timestamp to avoid resampling issues
+            collect(frame)
+        collect(None)  # flush the samples still buffered in the resampler
+
+    if chunks:
+        decoded_audio = np.concatenate(chunks, axis=1)
+    else:
+        decoded_audio = np.zeros((channels, 0), dtype=np.float32)
+
+    if mono:
+        # Collapse the channel axis; the mean of one channel is that channel.
+        decoded_audio = decoded_audio.mean(0)
+
+    if sr is not None:
+        return decoded_audio
+    return decoded_audio, rate