Skip to content

Commit 00da499

Browse files
committed
Add option to convert weights from the pytorch format.
1 parent 52c940a commit 00da499

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

f5_tts_mlx/cfm.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from __future__ import annotations
1111
from datetime import datetime
12+
import os
1213
from pathlib import Path
1314
from random import random
1415
from typing import Callable, Literal
@@ -377,7 +378,7 @@ def fn(t, x):
377378
return out, trajectory
378379

379380
@classmethod
380-
def from_pretrained(cls, hf_model_name_or_path: str) -> F5TTS:
381+
def from_pretrained(cls, hf_model_name_or_path: str, convert_weights = False) -> F5TTS:
381382
path = fetch_from_hub(hf_model_name_or_path)
382383

383384
if path is None:
@@ -435,6 +436,40 @@ def from_pretrained(cls, hf_model_name_or_path: str) -> F5TTS:
435436
)
436437

437438
weights = mx.load(model_path.as_posix(), format="safetensors")
439+
440+
if convert_weights:
441+
new_weights = {}
442+
for k, v in weights.items():
443+
k = k.replace('ema_model.', '')
444+
445+
# rename layers
446+
if len(k) < 1 or 'mel_spec.' in k or k in ('initted', 'step'):
447+
continue
448+
elif '.to_out' in k:
449+
k = k.replace('.to_out', '.to_out.layers')
450+
elif '.text_blocks' in k:
451+
k = k.replace('.text_blocks', '.text_blocks.layers')
452+
elif '.ff.ff.0.0' in k:
453+
k = k.replace('.ff.ff.0.0', '.ff.ff.layers.0.layers.0')
454+
elif '.ff.ff.2' in k:
455+
k = k.replace('.ff.ff.2', '.ff.ff.layers.2')
456+
elif '.time_mlp' in k:
457+
k = k.replace('.time_mlp', '.time_mlp.layers')
458+
elif '.conv1d' in k:
459+
k = k.replace('.conv1d', '.conv1d.layers')
460+
461+
# reshape weights
462+
if '.dwconv.weight' in k:
463+
v = v.swapaxes(1, 2)
464+
elif '.conv1d.layers.0.weight' in k:
465+
v = v.swapaxes(1, 2)
466+
elif '.conv1d.layers.2.weight' in k:
467+
v = v.swapaxes(1, 2)
468+
469+
new_weights[k] = v
470+
471+
weights = new_weights
472+
438473
f5tts.load_weights(list(weights.items()))
439474
mx.eval(f5tts.parameters())
440475

0 commit comments

Comments
 (0)