Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jul 29, 2024
2 parents fd55a09 + 680e046 commit 8bd6b17
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def _load(
use_flash_attn=use_flash_attn,
use_vllm=use_vllm,
device=device,
device_gpt=self.device_gpt,
logger=self.logger,
).eval()
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
Expand Down
4 changes: 2 additions & 2 deletions ChatTTS/model/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,12 @@ def _decode_prompt(prompt: str) -> torch.Tensor:
dtype="<u2",
).copy()
del dec
return torch.from_numpy(p).view(*shp)
return torch.from_numpy(p.astype(np.int32)).view(*shp)

@staticmethod
@torch.no_grad()
def _encode_prompt(prompt: torch.Tensor) -> str:
arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy()
arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16)
shp = arr.shape
assert len(shp) == 2, "prompt must be a 2D tensor"
s = b14.encode_to_string(
Expand Down

0 comments on commit 8bd6b17

Please sign in to comment.