Skip to content

Commit

Permalink
Merge pull request #791 from PromtEngineer/llama3
Browse files Browse the repository at this point in the history
added support for llama3
  • Loading branch information
PromtEngineer authored May 3, 2024
2 parents e997a8a + c1f04b5 commit cf530d2
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 11 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ For `NVIDIA` GPU support, use `cuBLAS`

```shell
# Example: cuBLAS
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```

For Apple Metal (`M1/M2`) support, use:

```shell
# Example: METAL
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```
For more details, please refer to [llama-cpp](https://github.com/abetlen/llama-cpp-python#installation-with-openblas--cublas--clblast--metal)

Expand Down
17 changes: 14 additions & 3 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)

# Context Window and Max New Tokens
CONTEXT_WINDOW_SIZE = 4096
CONTEXT_WINDOW_SIZE = 8096
MAX_NEW_TOKENS = CONTEXT_WINDOW_SIZE # int(CONTEXT_WINDOW_SIZE/4)

#### If you get a "not enough space in the buffer" error, reduce the values below: start with half of the original values and keep halving them until the error stops appearing.
Expand Down Expand Up @@ -100,8 +100,19 @@
# MODEL_ID = "TheBloke/Llama-2-13b-Chat-GGUF"
# MODEL_BASENAME = "llama-2-13b-chat.Q4_K_M.gguf"

MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF"
MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf"
# MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF"
# MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf"

# MODEL_ID = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
# MODEL_BASENAME = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"

# LLAMA 3 # use for Apple Silicon
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_BASENAME = None

# LLAMA 3 # use for NVIDIA GPUs
# MODEL_ID = "unsloth/llama-3-8b-bnb-4bit"
# MODEL_BASENAME = None

# MODEL_ID = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
# MODEL_BASENAME = "mistral-7b-instruct-v0.1.Q8_0.gguf"
Expand Down
16 changes: 13 additions & 3 deletions load_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,19 @@ def load_full_model(model_id, model_basename, device_type, logging):
"""

if device_type.lower() in ["mps", "cpu"]:
logging.info("Using LlamaTokenizer")
tokenizer = LlamaTokenizer.from_pretrained(model_id, cache_dir="./models/")
model = LlamaForCausalLM.from_pretrained(model_id, cache_dir="./models/")
logging.info("Using AutoModelForCausalLM")
# tokenizer = LlamaTokenizer.from_pretrained(model_id, cache_dir="./models/")
# model = LlamaForCausalLM.from_pretrained(model_id, cache_dir="./models/")

model = AutoModelForCausalLM.from_pretrained(model_id,
# quantization_config=quantization_config,
# low_cpu_mem_usage=True,
# torch_dtype="auto",
torch_dtype=torch.bfloat16,
device_map="auto",
cache_dir="./models/")

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./models/")
else:
logging.info("Using AutoModelForCausalLM for full models")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./models/")
Expand Down
24 changes: 24 additions & 0 deletions prompt_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, h

prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

elif promptTemplate_type == "llama3":

B_INST, E_INST = "<|start_header_id|>user<|end_header_id|>", "<|eot_id|>"
B_SYS, E_SYS = "<|begin_of_text|><|start_header_id|>system<|end_header_id|> ", "<|eot_id|>"
ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
if history:
instruction = """
Context: {history} \n {context}
User: {question}"""

prompt_template = SYSTEM_PROMPT + B_INST + instruction + ASSISTANT_INST
prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
else:
instruction = """
Context: {context}
User: {question}"""

prompt_template = SYSTEM_PROMPT + B_INST + instruction + ASSISTANT_INST
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

elif promptTemplate_type == "mistral":
B_INST, E_INST = "<s>[INST] ", " [/INST]"
if history:
Expand Down Expand Up @@ -82,6 +104,8 @@ def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, h

memory = ConversationBufferMemory(input_key="question", memory_key="history")

print(f"Here is the prompt used: {prompt}")

return (
prompt,
memory,
Expand Down
6 changes: 3 additions & 3 deletions run_localGPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,11 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
)
@click.option(
"--model_type",
default="llama",
default="llama3",
type=click.Choice(
["llama", "mistral", "non_llama"],
["llama3", "llama", "mistral", "non_llama"],
),
help="model type, llama, mistral or non_llama",
help="model type, llama3, llama, mistral or non_llama",
)
@click.option(
"--save_qa",
Expand Down

0 comments on commit cf530d2

Please sign in to comment.