Skip to content

Commit

Permalink
[Model] Add support for xverse (ggerganov#6301)
Browse files Browse the repository at this point in the history
* Support xverse model convert to gguf format.

* 1. Convert xverse models to gguf;
2. Add LLM_ARCH_XVERSE inference in llama.cpp;
3. Add xverse item in Supported models in README.md;

* * gguf-py: remove redundant logs
* llama: remove the init_mapping_prefetch custom parameter

* llama.cpp: Include the changes from ggerganov#6122 to exclude the unused outputs of the last layers.

* - Fix format issues
- Remove duplicate set kqv_out to llm_build_kv

* Update llama.cpp

---------

Co-authored-by: willhe <[email protected]>
Co-authored-by: willhe <[email protected]>
  • Loading branch information
3 people authored and hodlen committed Apr 3, 2024
1 parent c33eeba commit d70d531
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -115,6 +115,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
- [x] [Mamba](https://github.com/state-spaces/mamba)
- [x] [Xverse](https://huggingface.co/models?search=xverse)
- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)

**Multimodal models:**
Expand Down
142 changes: 142 additions & 0 deletions convert-hf-to-gguf.py
Expand Up @@ -773,6 +773,148 @@ def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
return weights[r * n_part:r * n_part + r, ...]


@Model.register("XverseForCausalLM")
class XverseModel(Model):
model_arch = gguf.MODEL_ARCH.XVERSE

def set_vocab(self):
assert (self.dir_model / "tokenizer.json").is_file()
dir_model = self.dir_model
hparams = self.hparams

tokens: list[bytearray] = []
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()

for token_id in range(vocab_size):
token_text = reverse_vocab[token_id].encode('utf-8')
# replace "\x00" to string with length > 0
if token_text == b"\x00":
toktype = gguf.TokenType.BYTE # special
token_text = f"<{token_text}>".encode('utf-8')
elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
toktype = gguf.TokenType.BYTE # special
elif reverse_vocab[token_id] in added_vocab:
if tokenizer.added_tokens_decoder[token_id].special:
toktype = gguf.TokenType.CONTROL
else:
toktype = gguf.TokenType.USER_DEFINED
else:
toktype = gguf.TokenType.NORMAL

tokens.append(token_text)
toktypes.append(toktype)

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
hf_repo = self.hparams.get("_name_or_path", "")

ctx_length = 0
if "max_sequence_length" in self.hparams:
ctx_length = self.hparams["max_sequence_length"]
elif "max_position_embeddings" in self.hparams:
ctx_length = self.hparams["max_position_embeddings"]
elif "model_max_length" in self.hparams:
ctx_length = self.hparams["model_max_length"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()

self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_source_hf_repo(hf_repo)
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

def write_tensors(self):
# Collect tensors from generator object
model_kv = dict(self.get_tensors())
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
head_count_kv = self.hparams.get("num_key_value_heads", head_count)

for name, data_torch in model_kv.items():
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

# HF models permute some of the tensors, so we need to undo that
if name.endswith(("q_proj.weight")):
data_torch = self._reverse_hf_permute(data_torch, head_count, head_count)
if name.endswith(("k_proj.weight")):
data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv)

data = data_torch.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)

def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head

return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)


@Model.register("FalconForCausalLM", "RWForCausalLM")
class FalconModel(Model):
model_arch = gguf.MODEL_ARCH.FALCON
Expand Down
22 changes: 22 additions & 0 deletions gguf-py/gguf/constants.py
Expand Up @@ -123,6 +123,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
STARCODER2 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()


Expand Down Expand Up @@ -191,6 +192,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
}

Expand Down Expand Up @@ -606,6 +608,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.COMMAND_R: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down Expand Up @@ -650,6 +668,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}

#
Expand Down

0 comments on commit d70d531

Please sign in to comment.