diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 187517796..1bbfdbcd1 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -1,4 +1,4 @@
-import os
+import os, platform
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 """
@@ -24,21 +24,6 @@
 from ..utils import del_all
 
 
-"""class LlamaMLP(nn.Module):
-    def __init__(self, hidden_size, intermediate_size):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = F.silu
-
-    def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj"""
-
-
 class GPT(nn.Module):
     def __init__(
         self,
@@ -115,12 +100,12 @@ def _build_llama(
         llama_config = LlamaConfig(**config)
 
         model = None
-        if "cuda" in str(device):
+        if "cuda" in str(device) and platform.system().lower() == "linux":
             try:
                 from .cuda import TELlamaModel
 
                 model = TELlamaModel(llama_config)
-                self.logger.info("use NVIDIA accelerated TELlamaModel")
+                self.logger.info("Linux with CUDA, try NVIDIA accelerated TELlamaModel")
             except Exception as e:
                 model = None
                 self.logger.warn(
diff --git a/README.md b/README.md
index bb54dd663..be8307363 100644
--- a/README.md
+++ b/README.md
@@ -92,7 +92,7 @@ conda activate chattts
 pip install -r requirements.txt
 ```
 
-#### Optional: Install TransformerEngine if using NVIDIA GPU
+#### Optional: Install TransformerEngine if using NVIDIA GPU (Linux only)
 
 > [!Note]
 > The installation process is very slow.
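
For readers skimming the patch, the control flow it introduces reduces to a try-with-fallback: attempt the TransformerEngine-backed `TELlamaModel` only when the target device is CUDA *and* the host OS is Linux, and fall back to the stock model otherwise. Below is a minimal standalone sketch of that pattern; the `build_llama_model` helper and the plain HuggingFace `LlamaModel` fallback are illustrative assumptions, not code taken from this diff.

```python
import logging
import platform

import torch
from transformers import LlamaConfig, LlamaModel

logger = logging.getLogger(__name__)


def build_llama_model(config: dict, device: torch.device):
    # Hypothetical helper mirroring the patched `GPT._build_llama` logic.
    llama_config = LlamaConfig(**config)

    model = None
    # TransformerEngine publishes Linux-only wheels, so check the OS in
    # addition to the device type before attempting the import.
    if "cuda" in str(device) and platform.system().lower() == "linux":
        try:
            from ChatTTS.model.cuda import TELlamaModel  # import path as in the diff

            model = TELlamaModel(llama_config)
            logger.info("Linux with CUDA, try NVIDIA accelerated TELlamaModel")
        except Exception as e:
            model = None
            logger.warning("TELlamaModel unavailable, falling back: %s", e)

    if model is None:
        # Portable fallback: the stock HuggingFace Llama implementation.
        model = LlamaModel(llama_config)
    return model
```

Gating on `platform.system()` up front, rather than relying on the `except` branch alone, spares Windows and macOS users a doomed import of TransformerEngine, which ships Linux-only wheels, and is why the README note now reads "(Linux only)".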