Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache background_noise rms data #145

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions torch_audiomentations/augmentations/background_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
sample_rate: int = None,
target_rate: int = None,
output_type: Optional[str] = None,
cache_background_data: bool = False
):
"""

Expand All @@ -49,6 +50,7 @@ def __init__(
:param p:
:param p_mode:
:param sample_rate:
        :param cache_background_data: if True, keep each background file's RMS-normalized samples and length in memory after first load
"""

super().__init__(
Expand All @@ -62,7 +64,10 @@ def __init__(

# TODO: check that one can read audio files
self.background_paths = find_audio_files_in_paths(background_paths)

self.background_rms_data = {}
self.cache_background_data = cache_background_data
if self.cache_background_data:
print("background_paths:", len(self.background_paths))
if sample_rate is not None:
self.audio = Audio(sample_rate=sample_rate, mono=True)

Expand All @@ -81,30 +86,51 @@ def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tens

missing_num_samples = target_num_samples
while missing_num_samples > 0:

background_path = random.choice(self.background_paths)
background_num_samples = audio.get_num_samples(background_path)

if self.cache_background_data:
if background_path not in self.background_rms_data:
background_num_samples = audio.get_num_samples(background_path)
self.background_rms_data[background_path] = (audio.rms_normalize(audio(background_path)), background_num_samples)
else:
background_num_samples = self.background_rms_data[background_path][1]
else:
background_num_samples = audio.get_num_samples(background_path)

if background_num_samples > missing_num_samples:
sample_offset = random.randint(
0, background_num_samples - missing_num_samples
)
num_samples = missing_num_samples
background_samples = audio(
background_path, sample_offset=sample_offset, num_samples=num_samples
)
if self.cache_background_data:
background_samples = self.background_rms_data[background_path][0][:, sample_offset:sample_offset + num_samples]
else:
background_samples = audio(
background_path, sample_offset=sample_offset, num_samples=num_samples
)
missing_num_samples = 0
else:
background_samples = audio(background_path)
if self.cache_background_data:
background_samples = self.background_rms_data[background_path][0]
else:
background_samples = audio(background_path)
missing_num_samples -= background_num_samples

pieces.append(background_samples)

# the inner call to rms_normalize ensures concatenated pieces share the same RMS (1)
# the outer call to rms_normalize ensures that the resulting background has an RMS of 1
# (this simplifies "apply_transform" logic)
return audio.rms_normalize(
torch.cat([audio.rms_normalize(piece) for piece in pieces], dim=1)
)
if self.cache_background_data:
ret = audio.rms_normalize(
torch.cat(pieces, dim=1)
)
return ret
else:
return audio.rms_normalize(
torch.cat([audio.rms_normalize(piece) for piece in pieces], dim=1)
)

def randomize_parameters(
self,
Expand Down Expand Up @@ -167,8 +193,8 @@ def apply_transform(

return ObjectDict(
samples=samples
+ background_rms.unsqueeze(-1)
* background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1),
+ background_rms.unsqueeze(-1)
* background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1),
sample_rate=sample_rate,
targets=targets,
target_rate=target_rate,
Expand Down