From eaaa855a2b01c2cbc566c342b2316d43c1015067 Mon Sep 17 00:00:00 2001
From: rqg
Date: Fri, 17 Jun 2022 17:15:38 +0800
Subject: [PATCH] cache background_noise rms data

---
 .../augmentations/background_noise.py | 48 ++++++++++++++-----
 1 file changed, 37 insertions(+), 11 deletions(-)

diff --git a/torch_audiomentations/augmentations/background_noise.py b/torch_audiomentations/augmentations/background_noise.py
index 40c24b86..7fa7b41a 100644
--- a/torch_audiomentations/augmentations/background_noise.py
+++ b/torch_audiomentations/augmentations/background_noise.py
@@ -38,6 +38,7 @@ def __init__(
         sample_rate: int = None,
         target_rate: int = None,
         output_type: Optional[str] = None,
+        cache_background_data: bool = False
     ):
         """

@@ -49,6 +50,7 @@
         :param p:
         :param p_mode:
         :param sample_rate:
+        :param cache_background_data:
         """

         super().__init__(
@@ -62,7 +64,10 @@

         # TODO: check that one can read audio files
         self.background_paths = find_audio_files_in_paths(background_paths)
-
+        self.background_rms_data = {}
+        self.cache_background_data = cache_background_data
+        if self.cache_background_data:
+            print("background_paths:", len(self.background_paths))
         if sample_rate is not None:
             self.audio = Audio(sample_rate=sample_rate, mono=True)

@@ -81,20 +86,35 @@ def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tens

         missing_num_samples = target_num_samples
         while missing_num_samples > 0:
+
             background_path = random.choice(self.background_paths)
-            background_num_samples = audio.get_num_samples(background_path)
+
+            if self.cache_background_data:
+                if background_path not in self.background_rms_data:
+                    background_num_samples = audio.get_num_samples(background_path)
+                    self.background_rms_data[background_path] = (audio.rms_normalize(audio(background_path)), background_num_samples)
+                else:
+                    background_num_samples = self.background_rms_data[background_path][1]
+            else:
+                background_num_samples = audio.get_num_samples(background_path)

             if background_num_samples > missing_num_samples:
                 sample_offset = random.randint(
                     0, background_num_samples - missing_num_samples
                 )
                 num_samples = missing_num_samples
-                background_samples = audio(
-                    background_path, sample_offset=sample_offset, num_samples=num_samples
-                )
+                if self.cache_background_data:
+                    background_samples = self.background_rms_data[background_path][0][:, sample_offset:sample_offset + num_samples]
+                else:
+                    background_samples = audio(
+                        background_path, sample_offset=sample_offset, num_samples=num_samples
+                    )
                 missing_num_samples = 0
             else:
-                background_samples = audio(background_path)
+                if self.cache_background_data:
+                    background_samples = self.background_rms_data[background_path][0]
+                else:
+                    background_samples = audio(background_path)
                 missing_num_samples -= background_num_samples

             pieces.append(background_samples)
@@ -102,9 +122,15 @@ def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tens
         # the inner call to rms_normalize ensures concatenated pieces share the same RMS (1)
         # the outer call to rms_normalize ensures that the resulting background has an RMS of 1
         # (this simplifies "apply_transform" logic)
-        return audio.rms_normalize(
-            torch.cat([audio.rms_normalize(piece) for piece in pieces], dim=1)
-        )
+        if self.cache_background_data:
+            ret = audio.rms_normalize(
+                torch.cat(pieces, dim=1)
+            )
+            return ret
+        else:
+            return audio.rms_normalize(
+                torch.cat([audio.rms_normalize(piece) for piece in pieces], dim=1)
+            )

     def randomize_parameters(
         self,
@@ -167,8 +193,8 @@ def apply_transform(

         return ObjectDict(
             samples=samples
-            + background_rms.unsqueeze(-1)
-            * background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1),
+            + background_rms.unsqueeze(-1)
+            * background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1),
             sample_rate=sample_rate,
             targets=targets,
             target_rate=target_rate,
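
Usage note (not part of the patch): a minimal sketch of how the new cache_background_data flag could be exercised, assuming the patched class is the AddBackgroundNoise transform exported by torch_audiomentations and that the constructor arguments shown match the installed version. The background directory, SNR bounds, and batch shape are placeholder values.

import torch
from torch_audiomentations import AddBackgroundNoise

# Placeholder path: any directory of audio files the library can read.
transform = AddBackgroundNoise(
    background_paths="/path/to/background_noises",
    min_snr_in_db=3.0,
    max_snr_in_db=30.0,
    p=1.0,
    sample_rate=16000,
    output_type="dict",
    cache_background_data=True,  # new flag added by this patch
)

# Batch of 8 mono clips, 1 second each at 16 kHz.
samples = torch.rand(8, 1, 16000) * 2 - 1

# Repeated calls reuse the cached, RMS-normalized background tensors once a
# file has been loaded, instead of decoding it from disk on every draw.
for _ in range(4):
    noisy = transform(samples, sample_rate=16000).samples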