 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-from functools import cache
-from typing import List
-
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
+import torch
+from typing import List
+from functools import cache
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+    BitsAndBytesConfig,
+)
 from pilot.configs.model_config import DEVICE
+from pilot.configs.config import Config
+
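+# Shared 4-bit NF4 quantization config: weights are stored as 4-bit NF4,
+# matmuls are computed in bfloat16, and double quantization is disabled.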
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype="bfloat16",
+    bnb_4bit_use_double_quant=False,
+)
+CFG = Config()
 
 
 class BaseLLMAdaper:
@@ -97,16 +112,44 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict): |
         return model, tokenizer
 
 
-class GuanacoAdapter(BaseLLMAdaper):
+class FalconAdapter(BaseLLMAdaper):
+    """Falcon adapter."""
+
+    def match(self, model_path: str):
+        return "falcon" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
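+        # CFG.QLoRA selects 4-bit quantized loading via bnb_config;
+        # device_map={"": 0} places the whole model on the first GPU.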
+        if CFG.QLoRA:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                quantization_config=bnb_config,  # load_in_4bit is already set inside bnb_config
+                device_map={"": 0},
+                trust_remote_code=True,
+                **from_pretrained_kwargs,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                device_map={"": 0},
+                **from_pretrained_kwargs,
+            )
+        return model, tokenizer
+
+
+class GorillaAdapter(BaseLLMAdaper):
     """TODO Support gorilla"""
 
     def match(self, model_path: str):
-        return "guanaco" in model_path
+        return "gorilla" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
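+        # low_cpu_mem_usage loads weights shard-by-shard instead of
+        # materializing a full extra copy of the model in CPU RAM.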
         model = AutoModelForCausalLM.from_pretrained(
-            model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
         )
         return model, tokenizer
 
@@ -166,6 +209,8 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict): |
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
+register_llm_model_adapters(FalconAdapter)
+register_llm_model_adapters(GorillaAdapter)
 # TODO: Vicuna is supported by default; other models still need testing and evaluation.
 
 # just for test, remove this later
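For orientation, the registration calls above feed a module-level adapter registry, and a lookup helper elsewhere in this file presumably resolves a model path to the first adapter whose `match()` accepts it. A minimal sketch, assuming the registry list and helper name (neither appears in this hunk):

```python
# Hypothetical sketch; these names are assumptions, not the file's confirmed API.
llm_model_adapters: List[BaseLLMAdaper] = []


def register_llm_model_adapters(cls):
    """Instantiate an adapter class and append it to the registry."""
    llm_model_adapters.append(cls())


def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    # Return the first registered adapter whose match() accepts the path.
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")
```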