feat(gpt): add flash_attention_2 support
fumiama committed Jul 2, 2024
1 parent 3fa30ce commit c109089
Showing 3 changed files with 24 additions and 9 deletions.
10 changes: 6 additions & 4 deletions ChatTTS/model/gpt.py
```diff
@@ -19,6 +19,7 @@
 from transformers import LlamaModel, LlamaConfig, LogitsWarper
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils import is_flash_attn_2_available
 
 from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
 from ..utils import del_all
```
```diff
@@ -97,20 +98,21 @@ def get(self) -> bool:
     def _build_llama(
         self, config: omegaconf.DictConfig, device: torch.device
     ) -> LlamaModel:
-        llama_config = LlamaConfig(**config)
-
         model = None
         if "cuda" in str(device) and platform.system().lower() == "linux":
             try:
                 from .cuda import TELlamaModel
 
-                model = TELlamaModel(llama_config)
+                model = TELlamaModel(LlamaConfig(**config))
                 self.logger.info("Linux with CUDA, try NVIDIA accelerated TELlamaModel")
             except Exception as e:
                 model = None
                 self.logger.warn(
                     f"use default LlamaModel for importing TELlamaModel error: {e}"
                 )
+        if is_flash_attn_2_available():
+            llama_config = LlamaConfig(**config, attn_implementation="flash_attention_2")
+        else:
+            llama_config = LlamaConfig(**config)
         if model is None:
             model = LlamaModel(llama_config)
             del model.embed_tokens
```
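The new branch keys the `attn_implementation` kwarg off FlashAttention-2 availability, building one of two separate `LlamaConfig` objects so the fallback path is identical to the pre-change behavior. A minimal sketch of that selection logic, using a plain dict in place of `LlamaConfig` so it runs without `transformers` installed — the `"flash_attention_2"` value matches the commit, while the helper name is made up for illustration:

```python
def make_llama_kwargs(config: dict, flash_attn_available: bool) -> dict:
    """Build the kwargs for the model config, requesting FlashAttention-2
    only when its kernels are actually importable."""
    kwargs = dict(config)  # copy so the caller's config is left untouched
    if flash_attn_available:
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


base = {"hidden_size": 768, "num_hidden_layers": 20}
print(make_llama_kwargs(base, True)["attn_implementation"])     # flash_attention_2
print("attn_implementation" in make_llama_kwargs(base, False))  # False
```

Omitting the key entirely in the fallback case (rather than passing an explicit default) leaves the attention backend choice to `transformers`, just as the pre-commit code did.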
20 changes: 16 additions & 4 deletions README.md
```diff
@@ -42,10 +42,10 @@ ChatTTS is a text-to-speech model designed specifically for dialogue scenarios s
 - The open-source version on **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** is a 40,000 hours pre-trained model without SFT.
 
 ### Roadmap
-- [x] Open-source the 40k hour base model and spk_stats file
-- [x] Streaming audio generation without refining the text*
-- [ ] Open-source the 40k hour version with multi-emotion control
-- [ ] ChatTTS.cpp maybe? (PR or new repo are welcomed.)
+- [x] Open-source the 40k hour base model and spk_stats file.
+- [x] Streaming audio generation.
+- [ ] Open-source the 40k hour version with multi-emotion control.
+- [ ] ChatTTS.cpp (new repo in `2noise` org is welcomed)
```

### Disclaimer
> [!Important]
@@ -95,10 +95,22 @@ pip install -r requirements.txt
> [!Note]
> The installation process is very slow.

> [!Warning]
> The TransformerEngine adaptation is still under development and CANNOT run properly yet.
> Install it only for development purposes.
```bash
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
```

#### Optional: Install FlashAttention-2 (mainly NVIDIA GPU)
> [!Note]
> See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
```bash
pip install flash-attn --no-build-isolation
```
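Once the wheel is installed, you can confirm at runtime which attention backend will be picked. A minimal check, assuming only that the `flash_attn` package is importable when the install step above succeeded; the `"sdpa"` fallback name here is illustrative (ChatTTS itself relies on `transformers.utils.is_flash_attn_2_available()`, which performs additional GPU checks):

```python
import importlib.util


def flash_attn_2_installed() -> bool:
    # True when the flash_attn package is importable, i.e. the
    # `pip install flash-attn` step above succeeded in this environment.
    return importlib.util.find_spec("flash_attn") is not None


impl = "flash_attention_2" if flash_attn_2_installed() else "sdpa"
print(f"attention implementation: {impl}")
```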

### Quick Start
> Make sure you are in the project root directory when executing the commands below.
3 changes: 2 additions & 1 deletion examples/ipynb/colab.ipynb
```diff
@@ -36,7 +36,8 @@
    "outputs": [],
    "source": [
     "!pip install -r /content/ChatTTS/requirements.txt\n",
-    "!ldconfig /usr/lib64-nvidia"
+    "!ldconfig /usr/lib64-nvidia\n",
+    "!pip install flash-attn --no-build-isolation"
    ]
   },
   {
```
