Skip to content

Commit 78e6c08

Browse files
committed
Fix long text generation.
1 parent 77071da commit 78e6c08

File tree

3 files changed: 6 additions (+6) and 6 deletions (-6).

3 files changed: 6 additions (+6) and 6 deletions (-6).

f5_tts_mlx/duration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ def __call__(
                 text = list_str_to_tensor(text)
             assert text.shape[0] == batch
 
+        if seq_len < text.shape[1]:
+            seq_len = text.shape[1]
+            inp = mx.pad(inp, [(0, 0), (0, seq_len - inp.shape[1]), (0, 0)])
+
         # lens and mask
         if not exists(lens):
             lens = mx.full((batch,), seq_len)

f5_tts_mlx/utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,14 @@ def maybe_masked_mean(t: mx.array, mask: mx.array | None = None) -> mx.array:
     return einx.divide("b d, b -> b d", num, mx.maximum(den, 1))
 
 
-def pad_to_length(t: mx.array, length: int, value=None):
+def pad_to_length(t: mx.array, length: int, value=0):
     ndim = t.ndim
     seq_len = t.shape[-1]
     if length > seq_len:
         if ndim == 1:
             t = mx.pad(t, [(0, length - seq_len)], constant_values=value)
         elif ndim == 2:
             t = mx.pad(t, [(0, 0), (0, length - seq_len)], constant_values=value)
-        elif ndim == 3:
-            t = mx.pad(
-                t, [(0, 0), (0, length - seq_len), (0, 0)], constant_values=value
-            )
         else:
             raise ValueError(f"Unsupported padding dims: {ndim}")
     return t[..., :length]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts-mlx"
-version = "0.1.5"
+version = "0.1.6"
 authors = [{name = "Lucas Newman", email = "[email protected]"}]
 license = {text = "MIT"}
 description = "F5-TTS - MLX"

0 commit comments

Comments
 (0)