@@ -83,6 +83,14 @@ def __init__(
83
83
self .mask_value = mask_value
84
84
self .use_vectorized_code = use_vectorized_code
85
85
86
+ if isinstance (time_masks , int ):
87
+ self .adaptive_temporal_masks = False
88
+ else :
89
+ if time_masks > 1.0 or time_masks < 0.0 :
90
+ raise ValueError ("If `time_masks` is a float value, must be in range [0, 1]" )
91
+
92
+ self .adaptive_temporal_masks = True
93
+
86
94
if isinstance (time_width , int ):
87
95
self .adaptive_temporal_width = False
88
96
else :
@@ -114,6 +122,12 @@ def _forward_legacy(self, input_spec, length):
114
122
width = self ._rng .randint (0 , self .freq_width )
115
123
fill_mask [idx , start : start + width , :] = True
116
124
125
+ # Derive the number of masks, sometimes based on a percentage of the input length.
126
+ if self .adaptive_temporal_masks :
127
+ time_max_masks = int (lengths_cpu [idx ] * self .time_masks )
128
+ else :
129
+ time_max_masks = self .time_masks
130
+
117
131
# Derive time width, sometimes based on a percentage of the input length.
118
132
if self .adaptive_temporal_width :
119
133
time_max_width = max (1 , int (lengths_cpu [idx ] * self .time_width ))
@@ -122,7 +136,7 @@ def _forward_legacy(self, input_spec, length):
122
136
time_start_upper_bound = max (1 , lengths_cpu [idx ] - time_max_width )
123
137
124
138
# Set time masking
125
- for _ in range (self . time_masks ):
139
+ for _ in range (time_max_masks ):
126
140
start = self ._rng .randint (0 , time_start_upper_bound )
127
141
width = self ._rng .randint (0 , time_max_width )
128
142
fill_mask [idx , :, start : start + width ] = True
0 commit comments