Skip to content

Commit

Permalink
Add split audio files
Browse files Browse the repository at this point in the history
  • Loading branch information
pooya-mohammadi committed Aug 1, 2024
1 parent 7f9bc35 commit 5d39560
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 31 deletions.
2 changes: 1 addition & 1 deletion deep_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .utils.lib_utils.integeration_utils import import_lazy_module

# Deep Utils version number
__version__ = "1.3.36"
__version__ = "1.3.37"

from .utils.constants import DUMMY_PATH, Backends

Expand Down
115 changes: 86 additions & 29 deletions deep_utils/audio/audio_utils/torchaudio_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from pathlib import Path
from typing import Union, List, Optional
from typing import Union, List, Optional, Tuple, overload

import numpy as np
import torch
Expand Down Expand Up @@ -86,51 +86,108 @@ def load(wave: Union[str, Path], sr: int, resample_rate: int = 0, resampled_path

@staticmethod
def split(
wave, sr: Optional[int] = None, max_seconds: float = 10, min_seconds=1, logger=None, verbose=1
) -> List[torch.Tensor]:
wave: str, durations: List[Tuple[int, int]] = None, sr: Optional[int] = None, max_seconds: int = None,
min_seconds: int = 1, overlap: int = 1, logger=None,
verbose=1
) -> Tuple[List[torch.Tensor], int]:
"""
Splits a wave to mini-waves based on input max_seconds. If the last segment's duration is less than min_seconds
it is combined with its previous segment.
:param wave:
:param sr:
:param max_seconds:
:param durations:
:param min_seconds:
:param max_seconds:
:param overlap:
:param logger:
:param verbose:
:return:
"""
wave, _ = TorchAudioUtils.load(wave, sr=sr)

wave_duration = TorchAudioUtils.get_duration(wave, sr)
if max_seconds is None or wave_duration < max_seconds:
return [wave]
def _split(wave: torch.Tensor, wave_duration, max_seconds, sr, unsqueeze: bool, overlap: int) -> List[
torch.Tensor]:

if max_seconds is None or wave_duration < max_seconds:
# wave, _ = TorchAudioUtils.load(wave, sr=sr)
wave = wave.unsqueeze(0) if unsqueeze else wave
return [wave]

n_intervals = math.ceil(wave_duration / max_seconds)
waves_list = []

for interval in range(n_intervals):
if overlap and interval != 0:
s = int((interval * max_seconds - overlap) * sr)
else:
s = int((interval * max_seconds) * sr)
e = int(((interval + 1) * max_seconds) * sr)
w = wave[s:e]
w_duration = TorchAudioUtils.get_duration(w, sr)
w = w.unsqueeze(0) if unsqueeze else w
if w_duration < min_seconds:
waves_list[-1] = torch.concat([waves_list[-1], w], dim=1)
else:
waves_list.append(w)
return waves_list

wave, sr_ = TorchAudioUtils.load(wave, sr=sr)
sr = sr or sr_

unsqueeze = False
if len(wave.shape) == 2:
if len(wave.shape) == 2 and wave.shape[0] == 1:
wave = wave.squeeze(0)
unsqueeze = True
elif len(wave.shape) == 2 and wave.shape[0] == 2:
wave = torch.mean(wave, dim=0)
unsqueeze = True
elif len(wave.shape) == 2:
raise ValueError(f"[ERROR] Input wave shape is: {wave.shape}")
elif len(wave.shape) == 1:
unsqueeze = True
n_intervals = math.ceil(wave_duration / max_seconds)
waves = []

for interval in range(n_intervals):
s = (interval * max_seconds) * sr
e = ((interval + 1) * max_seconds) * sr
w = wave[s:e]
w_duration = TorchAudioUtils.get_duration(w, sr)
w = w.unsqueeze(0) if unsqueeze else w
if w_duration < min_seconds:
waves[-1] = torch.concat([waves[-1], w], dim=1)
else:
waves.append(w)
if len(waves) > 1:
log_print(
logger,
f"Successfully split input wave to {len(waves)} waves!",
verbose=verbose,
)
return waves
# n_intervals = math.ceil(wave_duration / max_seconds)
if durations is not None:
waves = []
carry = None
for s, e in durations:
w_duration = e - s
s = int(s * sr)
e = int(e * sr)
# s = (interval * max_seconds) * sr
# e = ((interval + 1) * max_seconds) * sr
w = wave[s:e]

if max_seconds:
waves.extend(_split(w, w_duration, max_seconds, sr, unsqueeze, overlap))
else:
w = w.unsqueeze(0) if unsqueeze else w
if w_duration < min_seconds:
if carry:
carry = torch.concat([carry, w], dim=1)
else:
carry = w
else:
if carry is not None:
w = torch.concat([carry, w], dim=1)
carry = None
waves.append(w)
if len(waves) > 1:
log_print(
logger,
f"Successfully split input wave to {len(waves)} waves!",
verbose=verbose,
)
return waves, sr
elif max_seconds is not None:

wave_duration = TorchAudioUtils.get_duration(wave, sr)
waves = _split(wave, wave_duration, max_seconds, sr, unsqueeze, overlap)
if len(waves) > 1:
log_print(
logger,
f"Successfully split input wave to {len(waves)} waves!",
verbose=verbose,
)
return waves

@staticmethod
def get_duration(wave, sr):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import setuptools

VERSION = "1.3.36"
VERSION = "1.3.37"

long_description = open("Readme.md", mode="r", encoding="utf-8").read()

Expand Down

0 comments on commit 5d39560

Please sign in to comment.