Skip to content

Commit 9be7e8e

Browse files
committed
Add duration heuristic and make duration optional.
1 parent ff8d7f4 commit 9be7e8e

File tree

5 files changed

+47
-8
lines changed

5 files changed

+47
-8
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ pip install f5-tts-mlx
2020

2121
```bash
2222
python -m f5_tts_mlx.generate \
23-
--text "The quick brown fox jumped over the lazy dog." \
24-
--duration 3.5
23+
--text "The quick brown fox jumped over the lazy dog."
2524
```
2625

2726
See [examples/generate.py](./examples) for more options.

examples/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
To run the script, use the following format:
55

66
```bash
7-
python generate.py --text "Your input text here" --duration 10
7+
python generate.py --text "Your input text here"
88
```
99

1010
## Required Parameters

examples/generate.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import datetime
33
import pkgutil
4+
import re
45
from typing import Optional
56

67
import mlx.core as mx
@@ -22,12 +23,13 @@
2223

2324
def generate(
2425
generation_text: str,
25-
duration: float,
26+
duration: Optional[float] = None,
2627
model_name: str = "lucasnewman/f5-tts-mlx",
2728
ref_audio_path: Optional[str] = None,
2829
ref_audio_text: Optional[str] = None,
2930
cfg_strength: float = 2.0,
3031
sway_sampling_coef: float = -1.0,
32+
speed: float = 1.0, # used when duration is None as part of the duration heuristic
3133
seed: Optional[int] = None,
3234
output_path: str = "output.wav",
3335
):
@@ -52,13 +54,24 @@ def generate(
5254

5355
audio = mx.array(audio)
5456
ref_audio_duration = audio.shape[0] / SAMPLE_RATE
57+
print(f"Got reference audio with duration: {ref_audio_duration:.2f} seconds")
5558

5659
rms = mx.sqrt(mx.mean(mx.square(audio)))
5760
if rms < TARGET_RMS:
5861
audio = audio * TARGET_RMS / rms
5962

6063
# generate the audio for the given text
6164
text = convert_char_to_pinyin([ref_audio_text + " " + generation_text])
65+
66+
# use a heuristic to determine the duration if not provided
67+
if duration is None:
68+
ref_audio_len = audio.shape[0] // HOP_LENGTH
69+
zh_pause_punc = r"。,、;:?!"
70+
ref_text_len = len(ref_audio_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_audio_text))
71+
gen_text_len = len(generation_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, generation_text))
72+
duration_in_frames = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
73+
duration = (duration_in_frames / FRAMES_PER_SEC) - ref_audio_duration
74+
print(f"Using duration of {duration:.2f} seconds for generated speech.")
6275

6376
frame_duration = int((ref_audio_duration + duration) * FRAMES_PER_SEC)
6477
print(f"Generating {frame_duration} total frames of audio...")
@@ -104,7 +117,7 @@ def generate(
104117
parser.add_argument(
105118
"--duration",
106119
type=float,
107-
required=True,
120+
default=None,
108121
help="Duration of the generated audio in seconds",
109122
)
110123
parser.add_argument(
@@ -137,6 +150,12 @@ def generate(
137150
default=-1.0,
138151
help="Coefficient for sway sampling",
139152
)
153+
parser.add_argument(
154+
"--speed",
155+
type=float,
156+
default=1.0,
157+
help="Speed factor for the duration heuristic",
158+
)
140159
parser.add_argument(
141160
"--seed",
142161
type=int,
@@ -154,6 +173,7 @@ def generate(
154173
ref_audio_text=args.ref_text,
155174
cfg_strength=args.cfg,
156175
sway_sampling_coef=args.sway_coef,
176+
speed=args.speed,
157177
seed=args.seed,
158178
output_path=args.output,
159179
)

f5_tts_mlx/generate.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import datetime
33
import pkgutil
4+
import re
45
from typing import Optional
56

67
import mlx.core as mx
@@ -22,12 +23,13 @@
2223

2324
def generate(
2425
generation_text: str,
25-
duration: float,
26+
duration: Optional[float] = None,
2627
model_name: str = "lucasnewman/f5-tts-mlx",
2728
ref_audio_path: Optional[str] = None,
2829
ref_audio_text: Optional[str] = None,
2930
cfg_strength: float = 2.0,
3031
sway_sampling_coef: float = -1.0,
32+
speed: float = 1.0, # used when duration is None as part of the duration heuristic
3133
seed: Optional[int] = None,
3234
output_path: str = "output.wav",
3335
):
@@ -52,13 +54,24 @@ def generate(
5254

5355
audio = mx.array(audio)
5456
ref_audio_duration = audio.shape[0] / SAMPLE_RATE
57+
print(f"Got reference audio with duration: {ref_audio_duration:.2f} seconds")
5558

5659
rms = mx.sqrt(mx.mean(mx.square(audio)))
5760
if rms < TARGET_RMS:
5861
audio = audio * TARGET_RMS / rms
5962

6063
# generate the audio for the given text
6164
text = convert_char_to_pinyin([ref_audio_text + " " + generation_text])
65+
66+
# use a heuristic to determine the duration if not provided
67+
if duration is None:
68+
ref_audio_len = audio.shape[0] // HOP_LENGTH
69+
zh_pause_punc = r"。,、;:?!"
70+
ref_text_len = len(ref_audio_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_audio_text))
71+
gen_text_len = len(generation_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, generation_text))
72+
duration_in_frames = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
73+
duration = (duration_in_frames / FRAMES_PER_SEC) - ref_audio_duration
74+
print(f"Using duration of {duration:.2f} seconds for generated speech.")
6275

6376
frame_duration = int((ref_audio_duration + duration) * FRAMES_PER_SEC)
6477
print(f"Generating {frame_duration} total frames of audio...")
@@ -104,7 +117,7 @@ def generate(
104117
parser.add_argument(
105118
"--duration",
106119
type=float,
107-
required=True,
120+
default=None,
108121
help="Duration of the generated audio in seconds",
109122
)
110123
parser.add_argument(
@@ -137,6 +150,12 @@ def generate(
137150
default=-1.0,
138151
help="Coefficient for sway sampling",
139152
)
153+
parser.add_argument(
154+
"--speed",
155+
type=float,
156+
default=1.0,
157+
help="Speed factor for the duration heuristic",
158+
)
140159
parser.add_argument(
141160
"--seed",
142161
type=int,
@@ -154,6 +173,7 @@ def generate(
154173
ref_audio_text=args.ref_text,
155174
cfg_strength=args.cfg,
156175
sway_sampling_coef=args.sway_coef,
176+
speed=args.speed,
157177
seed=args.seed,
158178
output_path=args.output,
159179
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
66

77
[project]
88
name = "f5-tts-mlx"
9-
version = "0.0.9"
9+
version = "0.1.0"
1010
authors = [{name = "Lucas Newman", email = "[email protected]"}]
1111
license = {text = "MIT"}
1212
description = "F5-TTS - MLX"

0 commit comments

Comments
 (0)