144 changes: 51 additions & 93 deletions examples/quantize_and_pack_int4.ipynb
@@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -25,8 +25,7 @@
"from compressed_tensors.quantization import (\n",
" QuantizationConfig,\n",
" QuantizationStatus,\n",
" apply_quantization_config,\n",
" compress_quantized_weights\n",
" apply_quantization_config\n",
")\n",
"from compressed_tensors.compressors import ModelCompressor\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator\n",
@@ -37,51 +36,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c883cdc8ecd04866bd01d61796b81c26",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/560 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "32b18b14b6774ce7b61d2854a1ed5f49",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/4.40G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "370c6d18521a4b65833a411728be1ed7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/129 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
@@ -113,7 +70,7 @@
")"
]
},
"execution_count": 3,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -122,7 +79,7 @@
"# load a dense, unquantized tiny llama model\n",
"device = \"cuda:0\"\n",
"model_name = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=\"auto\")\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)\n",
"model"
]
},
@@ -139,7 +96,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -164,7 +121,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -177,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -198,14 +155,14 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Running calibration: 512it [00:33, 15.42it/s]\n"
"Running calibration: 512it [00:58, 8.82it/s]\n"
]
}
],
@@ -233,20 +190,24 @@
"\n",
"Notice that at this point, the weight itself is still a floating point and has not been quantized. \n",
"\n",
"To convert the weights to an integer type, we need to apply the `compress_quantized_weights` function. After compressing the weights, a forward pass of the model can no longer be run in PyTorch"
"To convert the weights to an integer type, we need to apply the `compress_model` function. After compressing the weights, a forward pass of the model can no longer be run in PyTorch.\n",
"\n",
"After compressing the quantized model with the `pack-quantized` format, weights are represented as logical int4 values packed into int32 containers ( `weight_packed` ), with the original shape recorded in `weight_shape`.\n",
"\n",
"This packed representation is what gets saved to disk when using ModelCompressor.compress_model(model)."
]
},
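As a rough sketch of the packing described above (not part of the notebook: the helper name `pack_int4_to_int32` and the exact nibble ordering are assumptions, and the actual `pack-quantized` layout in compressed-tensors may differ), eight signed int4 values can be stored in a single int32 container like this:

```python
import torch

def pack_int4_to_int32(qweight: torch.Tensor) -> torch.Tensor:
    """Pack int4 values (held in an int8 tensor, range [-8, 7]) into int32.

    A [rows, cols] tensor becomes [rows, cols // 8]: eight 4-bit nibbles
    fit into each 32-bit container.
    """
    assert qweight.shape[-1] % 8 == 0, "columns must be divisible by 8"
    # keep only the low 4 bits so negatives map to their two's-complement nibble
    nibbles = qweight.to(torch.int32) & 0xF
    packed = torch.zeros(qweight.shape[0], qweight.shape[1] // 8, dtype=torch.int32)
    for i in range(8):
        # place every 8th column into bit positions 4*i .. 4*i + 3
        packed |= nibbles[:, i::8] << (4 * i)
    return packed

# a 2048 x 2048 int4 weight packs into 2048 x 256 int32 containers,
# matching the weight_packed / weight_shape pair inspected later in the notebook
q = torch.randint(-8, 8, (2048, 2048), dtype=torch.int8)
print(pack_int4_to_int32(q).shape)  # torch.Size([2048, 256])
```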
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Scale: tensor([17296.], device='cuda:4', dtype=torch.float16), Zero Point: tensor([0], device='cuda:4', dtype=torch.int8)\n",
"Weight min: -1.587890625 max: 1.0283203125 dtype: torch.float16\n"
"Scale: tensor([-3.0465e+26], device='cuda:0', dtype=torch.bfloat16), Zero Point: tensor([0], device='cuda:0', dtype=torch.int8)\n",
"Weight min: -1.5859375 max: 1.03125 dtype: torch.bfloat16\n"
]
}
],
@@ -262,64 +223,62 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Compressing model: 154it [00:02, 59.75it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Scale: tensor([17296.], device='cuda:4', dtype=torch.float16), Zero Point: tensor([0], device='cuda:4', dtype=torch.int8)\n",
"Weight min: 0 max: 0 dtype: torch.int8\n"
"Compressed weight scale: tensor([-3.0465e+26], device='cuda:0', dtype=torch.bfloat16), zero point: tensor([0], device='cuda:0', dtype=torch.int8)\n",
"Compressed weight dtype: torch.int32\n",
"Compressed weight shape: torch.Size([2048, 256])\n",
"Uncompressed weight shape: tensor([2048, 2048], device='cuda:0')\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# convert quantized weights to integers\n",
"model.apply(compress_quantized_weights)\n",
"compressor = ModelCompressor(quantization_config=config)\n",
"compressor.compress_model(model)\n",
"\n",
"state_dict = model.state_dict()\n",
"example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n",
"scale = state_dict[example_layer + \"_scale\"]\n",
"zero_point = state_dict[example_layer + \"_zero_point\"]\n",
"weight = state_dict[example_layer]\n",
"print(f\"Scale: {scale}, Zero Point: {zero_point}\")\n",
"print(f\"Weight min: {torch.min(weight)} max: {torch.max(weight)} dtype: {weight.dtype}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After compressing the quantized model, the weight matrix has a range of int4 but is stored in an int8. \n",
"\n",
"We can further compress the model on disk using the `pack-quantized` format we specified in the config. This compression format will pack the int4 weights into int32"
"weight = state_dict[example_layer + \"_packed\"]\n",
"shape = state_dict[example_layer + \"_shape\"]\n",
"print(f\"Compressed weight scale: {scale}, zero point: {zero_point}\")\n",
"print(f\"Compressed weight dtype: {weight.dtype}\")\n",
"print(f\"Compressed weight shape: {weight.shape}\")\n",
"print(f\"Uncompressed weight shape: {shape}\")"
]
},
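The printed shapes follow from the packing ratio: each int32 container holds eight int4 values, so the 2048 x 2048 quantized weight shrinks along its last dimension to 2048 x (2048 / 8) = 2048 x 256 in `weight_packed`, while `weight_shape` records the original 2048 x 2048 so the logical dimensions can be recovered when the weights are unpacked.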
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compression format: pack-quantized\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Quantized Compression: 100%|██████████| 509/509 [00:03<00:00, 153.70it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Size of the model's weights on disk using safetensors: 712.23 MB\n"
"Compression format: pack-quantized\n",
"Size of the model's weights on disk using safetensors: 712.25 MB\n"
]
}
],
@@ -330,9 +289,8 @@
"compression_format = config.format\n",
"print(f\"Compression format: {compression_format}\")\n",
"\n",
"compressor = ModelCompressor(quantization_config=config)\n",
"compressed_state_dict = compressor.compress(model)\n",
"model.save_pretrained(output_dir, state_dict=compressed_state_dict)\n",
"\n",
"model.save_pretrained(output_dir, state_dict=model.state_dict())\n",
"compressor.update_config(output_dir)\n",
"\n",
"compressed_size_on_disk_mb = os.path.getsize(os.path.join(output_dir, \"model.safetensors\")) / 1024 / 1024\n",
@@ -356,7 +314,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.12"
}
},
"nbformat": 4,