Skip to content

Commit 7a1711f

Browse files
committed
Feature/adaptive temporal masks (#7)
* Introduce adaptive_temporal_masks so the number of time masks can scale with speech length * Bump gx version
1 parent 7192a2c commit 7a1711f

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

nemo/collections/asr/parts/submodules/spectr_augment.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def __init__(
8383
self.mask_value = mask_value
8484
self.use_vectorized_code = use_vectorized_code
8585

86+
if isinstance(time_masks, int):
87+
self.adaptive_temporal_masks = False
88+
else:
89+
if time_masks > 1.0 or time_masks < 0.0:
90+
raise ValueError("If `time_masks` is a float value, must be in range [0, 1]")
91+
92+
self.adaptive_temporal_masks = True
93+
8694
if isinstance(time_width, int):
8795
self.adaptive_temporal_width = False
8896
else:
@@ -114,6 +122,12 @@ def _forward_legacy(self, input_spec, length):
114122
width = self._rng.randint(0, self.freq_width)
115123
fill_mask[idx, start : start + width, :] = True
116124

125+
# Derive the number of time masks; when adaptive, it is a fraction of the input length.
126+
if self.adaptive_temporal_masks:
127+
time_max_masks = int(lengths_cpu[idx] * self.time_masks)
128+
else:
129+
time_max_masks = self.time_masks
130+
117131
# Derive time width, sometimes based on a percentage of input length.
118132
if self.adaptive_temporal_width:
119133
time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
@@ -122,7 +136,7 @@ def _forward_legacy(self, input_spec, length):
122136
time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)
123137

124138
# Set time masking
125-
for _ in range(self.time_masks):
139+
for _ in range(time_max_masks):
126140
start = self._rng.randint(0, time_start_upper_bound)
127141
width = self._rng.randint(0, time_max_width)
128142
fill_mask[idx, :, start : start + width] = True

nemo/package_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
MAJOR = 2
1717
MINOR = 2
1818
PATCH = 0
19-
PRE_RELEASE = ''
19+
PRE_RELEASE = '+gx1'
2020

2121
# Use the following formatting: (major, minor, patch, pre-release)
2222
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)

0 commit comments

Comments
 (0)