From 46200b31de7d906f2dbf1609e1f1470f91789be3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com>
Date: Tue, 2 Jul 2024 17:22:58 +0900
Subject: [PATCH] feat: add optional param `use_flash_attn`

---
 ChatTTS/core.py              |  5 ++++-
 ChatTTS/model/gpt.py         | 12 ++++++++----
 README.md                    |  4 ++++
 examples/ipynb/colab.ipynb   | 13 ++++++-------
 examples/ipynb/example.ipynb |  8 ++++----
 5 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index d73c9c26a..5fdb1c7ab 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -128,6 +128,7 @@ def load(
         custom_path: Optional[torch.serialization.FILE_LIKE] = None,
         device: Optional[torch.device] = None,
         coef: Optional[torch.Tensor] = None,
+        use_flash_attn=False,
     ) -> bool:
         download_path = self.download_models(source, force_redownload, custom_path)
         if download_path is None:
@@ -136,6 +137,7 @@ def load(
             device=device,
             compile=compile,
             coef=coef,
+            use_flash_attn=use_flash_attn,
             **{
                 k: os.path.join(download_path, v)
                 for k, v in OmegaConf.load(
@@ -255,6 +257,7 @@ def _load(
         device: Optional[torch.device] = None,
         compile: bool = True,
         coef: Optional[str] = None,
+        use_flash_attn=False,
     ):
         if device is None:
             device = select_device()
@@ -292,7 +295,7 @@ def _load(
 
         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT(**cfg, device=device, logger=self.logger).eval()
+            gpt = GPT(**cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger).eval()
             assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
             gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
             gpt.prepare(compile=compile and "cuda" in str(device))
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 4109adf59..e6f261163 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -32,6 +32,7 @@ def __init__(
         num_audio_tokens: int,
         num_text_tokens: int,
         num_vq=4,
+        use_flash_attn=False,
         device=torch.device("cpu"),
         logger=logging.getLogger(__name__),
     ):
@@ -45,6 +46,8 @@ def __init__(
         self.num_vq = num_vq
         self.num_audio_tokens = num_audio_tokens
 
+        self.use_flash_attn = use_flash_attn
+
         self.gpt = self._build_llama(gpt_config, self.device_gpt)
         self.model_dim = int(self.gpt.config.hidden_size)
         self.emb_code = nn.ModuleList(
@@ -96,7 +99,7 @@ def get(self) -> bool:
             return self._interrupt
 
     def _build_llama(
-        self, config: omegaconf.DictConfig, device: torch.device
+        self, config: omegaconf.DictConfig, device: torch.device,
     ) -> LlamaModel:
 
         model = None
@@ -114,11 +117,12 @@ def _build_llama(
                 )
 
         if model is None:
-            if is_flash_attn_2_available():
+            if self.use_flash_attn and is_flash_attn_2_available():
                 llama_config = LlamaConfig(
                     **config,
                     attn_implementation="flash_attention_2",
                 )
+                self.logger.warning("enabling flash_attention_2 may make generation even slower")
             else:
                 llama_config = LlamaConfig(**config)
             model = LlamaModel(llama_config)
@@ -127,7 +131,7 @@ def _build_llama(
         return model.to(device)
 
     def prepare(self, compile=False):
-        if is_flash_attn_2_available():
+        if self.use_flash_attn and is_flash_attn_2_available():
             self.gpt = self.gpt.to(dtype=torch.float16)
         if compile:
             try:
@@ -435,7 +439,7 @@ def generate(
             )
             del_all(model_input)
             attentions.append(outputs.attentions)
-            hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
+            hidden_states = outputs.last_hidden_state.to(self.device, dtype=torch.float)  # 🐻
             past_key_values = outputs.past_key_values
             del_all(outputs)
             if return_hidden:
diff --git a/README.md b/README.md
index f879e3a52..c2ca90e08 100644
--- a/README.md
+++ b/README.md
@@ -107,6 +107,10 @@ pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
 > [!Note]
 > See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
 
+> [!Warning]
+> Currently, FlashAttention-2 slows down generation speed, as reported in [this issue](https://github.com/huggingface/transformers/issues/26990).
+> Install it only for development purposes.
+
 ```bash
 pip install flash-attn --no-build-isolation
 ```
diff --git a/examples/ipynb/colab.ipynb b/examples/ipynb/colab.ipynb
index 04633869d..a42de33e8 100644
--- a/examples/ipynb/colab.ipynb
+++ b/examples/ipynb/colab.ipynb
@@ -36,8 +36,7 @@
   "outputs": [],
   "source": [
    "!pip install -r /content/ChatTTS/requirements.txt\n",
-    "!ldconfig /usr/lib64-nvidia\n",
-    "!pip install flash-attn --no-build-isolation"
+    "!ldconfig /usr/lib64-nvidia"
   ]
  },
  {
@@ -116,7 +115,7 @@
    "id": "3Ty427FZNH30"
   },
   "source": [
-    "### Here are three choices for loading models:"
+    "### Here are three choices for loading models"
   ]
  },
  {
@@ -125,7 +124,7 @@
    "id": "NInF7Lk1NH30"
   },
   "source": [
-    "#### 1. Load models from Hugging Face:"
+    "#### 1. Load models from Hugging Face (recommended)"
   ]
  },
  {
@@ -137,7 +136,7 @@
   "outputs": [],
   "source": [
    "# use force_redownload=True if the weights have been updated.\n",
-    "chat.load(source=\"huggingface\", force_redownload=True)"
+    "chat.load(source=\"huggingface\")"
   ]
  },
  {
@@ -146,7 +145,7 @@
    "id": "AhBD5WUPNH30"
   },
   "source": [
-    "#### 2. Load models from local directories 'asset' and 'config':"
+    "#### 2. Load models from local directories 'asset' and 'config'"
   ]
  },
  {
@@ -167,7 +166,7 @@
    "id": "c0qjGPNkNH31"
   },
   "source": [
-    "#### 3. Load models from a custom path:"
+    "#### 3. Load models from a custom path"
   ]
  },
  {
diff --git a/examples/ipynb/example.ipynb b/examples/ipynb/example.ipynb
index 2205a7ab2..167d02b69 100644
--- a/examples/ipynb/example.ipynb
+++ b/examples/ipynb/example.ipynb
@@ -79,14 +79,14 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "### Here are three choices for loading models:"
+    "### Here are three choices for loading models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "#### 1. Load models from Hugging Face:"
+    "#### 1. Load models from Hugging Face (not suitable in mainland China)"
   ]
  },
  {
@@ -103,7 +103,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "#### 2. Load models from local directories 'asset' and 'config':"
+    "#### 2. Load models from local directories 'asset' and 'config' (recommended)"
   ]
  },
  {
@@ -120,7 +120,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "#### 3. Load models from a custom path:"
+    "#### 3. Load models from a custom path"
   ]
  },
 {
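For reference, a minimal usage sketch of the new flag once this patch is applied. The `Chat()` construction, `load()`, and `infer()` calls follow the project's README and notebook examples; only the `use_flash_attn` argument is introduced by this patch, and it defaults to `False`, so existing callers are unaffected.

```python
import ChatTTS

chat = ChatTTS.Chat()

# use_flash_attn is opt-in: the default (False) keeps the previous behavior
# even when the flash-attn package happens to be installed. Pass True only
# for development, given the slowdown noted in the README warning above.
chat.load(source="huggingface", use_flash_attn=False)

# Inference itself is unchanged by this patch.
wavs = chat.infer(["Hello, this is a FlashAttention test."])
```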